Skip to content

Commit 702f414

Browse files
tomvdwThe TensorFlow Datasets Authors
authored andcommitted
Lazy load metadata to save disk access when user doesn't use metadata
PiperOrigin-RevId: 694450576
1 parent e14764e commit 702f414

File tree

2 files changed

+107
-10
lines changed

2 files changed

+107
-10
lines changed

tensorflow_datasets/core/dataset_info.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -747,12 +747,16 @@ def read_from_directory(self, dataset_info_dir: epath.PathLike) -> None:
747747
dataset_info_dir
748748
)
749749

750-
# Restore the MetaDataDict from metadata.json if there is any
751-
if _metadata_filepath(dataset_info_dir).exists():
752-
# If the dataset was loaded from file, self.metadata will be `None`, so
753-
# we create a MetadataDict first.
754-
if self._metadata is None:
755-
self._metadata = MetadataDict()
750+
# If the dataset was loaded from file, self.metadata will be `None`, so
751+
# we create a MetadataDict first.
752+
if self._metadata is None:
753+
self._metadata = LazyMetadataDict(dataset_info_dir)
754+
elif isinstance(self._metadata, MetadataDict):
755+
lazy_metadata = LazyMetadataDict(dataset_info_dir)
756+
lazy_metadata.update(self._metadata)
757+
self._metadata = lazy_metadata
758+
elif _metadata_filepath(dataset_info_dir).exists():
759+
# Restore the MetaDataDict from metadata.json if there is any
756760
self._metadata.load_metadata(dataset_info_dir)
757761

758762
# Update fields which are not defined in the code. This means that
@@ -1375,6 +1379,12 @@ def add_tfds_data_source_access(
13751379
)
13761380

13771381

1382+
def _load_metadata_from_file(data_dir: epath.PathLike) -> dict[str, Any]:
1383+
"""Loads metadata from file."""
1384+
with _metadata_filepath(data_dir).open(mode="r") as f:
1385+
return json.load(f)
1386+
1387+
13781388
class MetadataDict(Metadata, dict):
13791389
"""A `tfds.core.Metadata` object that acts as a `dict`.
13801390
@@ -1389,8 +1399,41 @@ def save_metadata(self, data_dir):
13891399
def load_metadata(self, data_dir):
13901400
"""Restore the metadata."""
13911401
self.clear()
1392-
with _metadata_filepath(data_dir).open(mode="r") as f:
1393-
self.update(json.load(f))
1402+
self.update(_load_metadata_from_file(data_dir))
1403+
1404+
1405+
class LazyMetadataDict(MetadataDict):
1406+
"""A `tfds.core.Metadata` object that acts as a `dict`.
1407+
1408+
Content is lazily loaded from the given data directory.
1409+
"""
1410+
1411+
def __init__(self, data_dir: epath.PathLike) -> None:
1412+
self._data_dir = epath.Path(data_dir)
1413+
self._data_is_loaded = False
1414+
super().__init__()
1415+
1416+
def _load_metadata(self):
1417+
if not self._data_is_loaded:
1418+
if _metadata_filepath(self._data_dir).exists():
1419+
self.load_metadata(self._data_dir)
1420+
self._data_is_loaded = True
1421+
1422+
def __getitem__(self, key, /):
1423+
self._load_metadata()
1424+
return super().__getitem__(key)
1425+
1426+
def __eq__(self, value, /):
1427+
self._load_metadata()
1428+
return super().__eq__(value)
1429+
1430+
def keys(self):
1431+
self._load_metadata()
1432+
return super().keys()
1433+
1434+
def items(self):
1435+
self._load_metadata()
1436+
return super().items()
13941437

13951438

13961439
class BeamMetadataDict(MetadataDict):

tensorflow_datasets/core/dataset_info_test.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""Tests for tensorflow_datasets.core.dataset_info."""
17-
1816
import json
1917
import os
2018
import pathlib
2119
import re
2220
import tempfile
2321
import time
2422
from typing import Union
23+
from unittest import mock
2524

25+
from etils import epath
2626
import numpy as np
2727
import pytest
2828
import tensorflow as tf
@@ -836,5 +836,59 @@ def test_get_split_info_from_proto_unavailable_format(self):
836836
# pylint: enable=g-inconsistent-quotes
837837

838838

839+
class LazyMetadataDictTest(testing.TestCase):
840+
841+
def setUp(self):
842+
super().setUp()
843+
self.data_dir = epath.Path(self.tmp_dir)
844+
metadata_file = self.data_dir / "metadata.json"
845+
metadata_file.write_text(json.dumps({"test": "test123"}))
846+
847+
def test_load_metadata(self):
848+
metadata_dict = dataset_info.LazyMetadataDict(self.data_dir)
849+
self.assertEqual(metadata_dict["test"], "test123")
850+
851+
@mock.patch.object(dataset_info, "_load_metadata_from_file")
852+
def test_get_item(self, mock_load_metadata_from_file):
853+
mock_load_metadata_from_file.return_value = {"test": "test123"}
854+
855+
metadata_dict = dataset_info.LazyMetadataDict(self.data_dir)
856+
mock_load_metadata_from_file.assert_not_called()
857+
858+
self.assertEqual(metadata_dict["test"], "test123")
859+
mock_load_metadata_from_file.assert_called_with(self.data_dir)
860+
861+
self.assertEqual(sorted(metadata_dict.keys()), ["test"])
862+
863+
# Update the metadata.
864+
metadata_dict["abc"] = "def"
865+
self.assertEqual(metadata_dict["abc"], "def")
866+
867+
@mock.patch.object(dataset_info, "_load_metadata_from_file")
868+
def test_keys(self, mock_load_metadata_from_file):
869+
mock_load_metadata_from_file.return_value = {"test": "test123"}
870+
metadata_dict = dataset_info.LazyMetadataDict(self.data_dir)
871+
mock_load_metadata_from_file.assert_not_called()
872+
actual_keys = sorted(metadata_dict.keys())
873+
mock_load_metadata_from_file.assert_called_with(self.data_dir)
874+
self.assertEqual(actual_keys, ["test"])
875+
876+
@mock.patch.object(dataset_info, "_load_metadata_from_file")
877+
def test_items(self, mock_load_metadata_from_file):
878+
mock_load_metadata_from_file.return_value = {"test": "test123"}
879+
metadata_dict = dataset_info.LazyMetadataDict(self.data_dir)
880+
mock_load_metadata_from_file.assert_not_called()
881+
self.assertNotEmpty(metadata_dict.items())
882+
mock_load_metadata_from_file.assert_called_with(self.data_dir)
883+
884+
@mock.patch.object(dataset_info, "_load_metadata_from_file")
885+
def test_updating_before_loading(self, mock_load_metadata_from_file):
886+
metadata_dict = dataset_info.LazyMetadataDict(self.data_dir)
887+
mock_load_metadata_from_file.assert_not_called()
888+
889+
metadata_dict.update({"abc": "def", "ghi": "jkl"})
890+
mock_load_metadata_from_file.assert_not_called()
891+
892+
839893
if __name__ == "__main__":
840894
testing.test_main()

0 commit comments

Comments
 (0)