|
18 | 18 | import json
|
19 | 19 | import os
|
20 | 20 | import pathlib
|
| 21 | +import re |
21 | 22 | import tempfile
|
22 | 23 | import time
|
23 | 24 | from typing import Union
|
@@ -698,6 +699,81 @@ def test_supports_file_format():
|
698 | 699 | )
|
699 | 700 |
|
700 | 701 |
|
class GetSplitInfoFromProtoTest(testing.TestCase):
  """Tests for `dataset_info.get_split_info_from_proto`."""

  def _dataset_info_proto_with_splits(self):
    """Builds a `DatasetInfo` proto holding a "train" and a "test" split.

    Both splits share the same shard lengths and byte count; only the "test"
    split sets an explicit `filepath_template`.
    """
    train_split = dataset_info_pb2.SplitInfo(
        name="train",
        shard_lengths=[1, 2, 3],
        num_bytes=42,
    )
    test_split = dataset_info_pb2.SplitInfo(
        name="test",
        shard_lengths=[1, 2, 3],
        num_bytes=42,
        filepath_template="{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}",
    )
    return dataset_info_pb2.DatasetInfo(
        name="dataset",
        file_format="tfrecord",
        alternative_file_formats=["riegeli"],
        splits=[train_split, test_split],
    )

  def _lookup_split(self, split_name, file_format):
    """Calls `get_split_info_from_proto` on a fresh fixture proto.

    Args:
      split_name: name of the split to look up.
      file_format: file format forwarded to the function under test.

    Returns:
      Whatever `dataset_info.get_split_info_from_proto` returns.
    """
    return dataset_info.get_split_info_from_proto(
        dataset_info_proto=self._dataset_info_proto_with_splits(),
        split_name=split_name,
        data_dir="/path/to/data",
        file_format=file_format,
    )

  def test_get_split_info_from_proto_undefined_filename_template(self):
    """A split without `filepath_template` gets the default template."""
    split_info = self._lookup_split(
        "train", file_adapters.FileFormat.TFRECORD
    )
    template = split_info.filename_template
    assert split_info.name == "train"
    assert split_info.shard_lengths == [1, 2, 3]
    assert split_info.num_bytes == 42
    assert template.dataset_name == "dataset"
    assert template.template == naming.DEFAULT_FILENAME_TEMPLATE

  def test_get_split_info_from_proto_defined_filename_template(self):
    """A split with an explicit `filepath_template` keeps that template."""
    split_info = self._lookup_split(
        "test", file_adapters.FileFormat.TFRECORD
    )
    template = split_info.filename_template
    assert split_info.name == "test"
    assert split_info.shard_lengths == [1, 2, 3]
    assert template.dataset_name == "dataset"
    assert template.template == "{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}"

  def test_get_split_info_from_proto_non_existing_split(self):
    """Looking up a split name absent from the proto yields `None`."""
    split_info = self._lookup_split(
        "undefined", file_adapters.FileFormat.TFRECORD
    )
    assert split_info is None

  def test_get_split_info_from_proto_unavailable_format(self):
    """A file format not listed in the proto raises `ValueError`."""
    expected_message = re.escape(
        "File format parquet does not match available dataset file formats:"
        " ['riegeli', 'tfrecord']."
    )
    with pytest.raises(ValueError, match=expected_message):
      self._lookup_split("undefined", file_adapters.FileFormat.PARQUET)
| 775 | + |
| 776 | + |
701 | 777 | # pylint: disable=g-inconsistent-quotes
|
702 | 778 | _INFO_STR = '''tfds.core.DatasetInfo(
|
703 | 779 | name='mnist',
|
|
0 commit comments