Skip to content

Commit 0c2acb3

Browse files
marcenacp and The TensorFlow Datasets Authors authored and committed
Retrieve description/citation metadata from 1) Hugging Face Hub (allegedly the source of truth among the Hugging Face APIs), then 2) Hugging Face Datasets.
PiperOrigin-RevId: 643005946
1 parent fa38a9d commit 0c2acb3

File tree

4 files changed

+55
-27
lines changed

4 files changed

+55
-27
lines changed

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@
4242
from tensorflow_datasets.core import example_serializer
4343
from tensorflow_datasets.core import features as feature_lib
4444
from tensorflow_datasets.core import file_adapters
45-
from tensorflow_datasets.core import lazy_imports_lib
4645
from tensorflow_datasets.core import split_builder as split_builder_lib
4746
from tensorflow_datasets.core import splits as splits_lib
4847
from tensorflow_datasets.core.utils import huggingface_utils
4948
from tensorflow_datasets.core.utils import shard_utils
5049
from tensorflow_datasets.core.utils import tqdm_utils
5150
from tensorflow_datasets.core.utils import version as version_lib
5251
from tensorflow_datasets.core.utils.lazy_imports_utils import datasets as hf_datasets
52+
from tensorflow_datasets.core.utils.lazy_imports_utils import huggingface_hub
5353

5454

5555
def _extract_supervised_keys(hf_info):
@@ -224,21 +224,21 @@ def __init__(
224224
or '1.0.0'
225225
)
226226
self.VERSION = version_lib.Version(version) # pylint: disable=invalid-name
227-
if self._hf_config:
228-
self._converted_builder_config = dataset_builder.BuilderConfig(
229-
name=tfds_config,
230-
version=self.VERSION,
231-
description=self._hf_info.description,
232-
)
233-
else:
234-
self._converted_builder_config = None
235227
self.name = huggingface_utils.convert_hf_name(hf_repo_id)
236228
self._hf_hub_token = hf_hub_token
237229
self._hf_num_proc = hf_num_proc
238230
self._tfds_num_proc = tfds_num_proc
239231
self._verification_mode = (
240232
'no_checks' if ignore_verifications else 'all_checks'
241233
)
234+
if self._hf_config:
235+
self._converted_builder_config = dataset_builder.BuilderConfig(
236+
name=tfds_config,
237+
version=self.VERSION,
238+
description=self._get_text_field('description'),
239+
)
240+
else:
241+
self._converted_builder_config = None
242242
super().__init__(
243243
file_format=file_format, config=tfds_config, data_dir=data_dir
244244
)
@@ -266,8 +266,16 @@ def _hf_download_and_prepare(self):
266266

267267
@property
268268
def _hf_info(self) -> hf_datasets.DatasetInfo:
269+
"""Retrieves the dataset info from the HuggingFace Datasets."""
269270
return self._hf_builder.info
270271

272+
@functools.cached_property
273+
def _hf_hub_info(self) -> huggingface_hub.hf_api.DatasetInfo:
274+
"""Retrieves the dataset info from the HuggingFace Hub and caches it."""
275+
return huggingface_hub.dataset_info(
276+
self._hf_repo_id, token=self._hf_hub_token
277+
)
278+
271279
def _hf_features(self) -> hf_datasets.Features:
272280
if not self._hf_info.features:
273281
# We need to download and prepare the data to know its features.
@@ -278,9 +286,9 @@ def _hf_features(self) -> hf_datasets.Features:
278286
def _info(self) -> dataset_info_lib.DatasetInfo:
279287
return dataset_info_lib.DatasetInfo(
280288
builder=self,
281-
description=self._hf_info.description,
289+
description=self._get_text_field('description'),
282290
features=huggingface_utils.convert_hf_features(self._hf_features()),
283-
citation=self._hf_info.citation,
291+
citation=self._get_text_field('citation'),
284292
license=self._get_license(),
285293
supervised_keys=_extract_supervised_keys(self._hf_info),
286294
)
@@ -417,24 +425,32 @@ def _write_shards(
417425

418426
def _get_license(self) -> str | None:
419427
"""Implements heuristics to get the license from HuggingFace."""
420-
# First heuristic: check the DatasetInfo from Hugging Face datasets.
421-
if self._hf_info.license:
422-
return self._hf_info.license
423-
huggingface_hub = lazy_imports_lib.lazy_imports.huggingface_hub
424-
# Retrieve the dataset info from the HuggingFace Hub.
425-
repo_id, token = self._hf_repo_id, self._hf_hub_token
426-
dataset_info = huggingface_hub.dataset_info(repo_id, token=token)
427-
# Second heuristic: check the card data.
428+
# Heuristic #1: check the DatasetInfo from Hugging Face Hub/Datasets.
429+
if info_license := self._get_text_field('license'):
430+
return info_license
431+
dataset_info = self._hf_hub_info
432+
# Heuristic #2: check the card data.
428433
if dataset_info.card_data:
429434
if card_data_license := dataset_info.card_data.get('license'):
430435
return card_data_license
431-
# Third heuristic: check the tags.
436+
# Heuristic #3: check the tags.
432437
if dataset_info.tags:
433438
for tag in dataset_info.tags:
434439
if tag.startswith('license:'):
435440
return tag.removeprefix('license:')
436441
return None
437442

443+
def _get_text_field(self, field: str) -> str | None:
444+
"""Get the field from either HF Hub or HF Datasets."""
445+
# The information retrieved from the Hub has priority over the one in the
446+
# builder, because the Hub is allegedly the new source of truth.
447+
for dataset_info in [self._hf_hub_info, self._hf_info]:
448+
# `description` and `citation` are not official fields in the Hugging Face
449+
# Hub API but they're still exposed in its __dict__.
450+
if value := getattr(dataset_info, field, None):
451+
return value
452+
return None
453+
438454

439455
def builder(
440456
name: str, config: Optional[str] = None, **builder_kwargs
@@ -449,5 +465,4 @@ def login_to_hf(hf_hub_token: Optional[str] = None):
449465
"""Logs in to Hugging Face Hub with the token as arg or env variable."""
450466
hf_hub_token = hf_hub_token or os.environ.get('HUGGING_FACE_HUB_TOKEN')
451467
if hf_hub_token is not None:
452-
huggingface_hub = lazy_imports_lib.lazy_imports.huggingface_hub
453468
huggingface_hub.login(token=hf_hub_token)

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pytest
2121
from tensorflow_datasets.core import lazy_imports_lib
2222
from tensorflow_datasets.core.dataset_builders import huggingface_dataset_builder
23+
from tensorflow_datasets.core.utils.lazy_imports_utils import huggingface_hub
2324

2425
PIL_Image = lazy_imports_lib.lazy_imports.PIL_Image
2526

@@ -72,6 +73,22 @@ def mock_login_to_hf():
7273
yield login_to_hf
7374

7475

76+
@pytest.fixture(autouse=True)
77+
def mock_hub_dataset_info():
78+
fake_dataset_info = huggingface_hub.hf_api.DatasetInfo(
79+
id='foo/bar',
80+
citation='citation from the hub',
81+
private=False,
82+
downloads=123,
83+
likes=456,
84+
tags=[],
85+
)
86+
with mock.patch.object(
87+
huggingface_hub, 'dataset_info', return_value=fake_dataset_info
88+
) as dataset_info:
89+
yield dataset_info
90+
91+
7592
@pytest.fixture(name='builder')
7693
def mock_huggingface_dataset_builder(
7794
tmp_path, load_dataset_builder, login_to_hf
@@ -91,7 +108,7 @@ def mock_huggingface_dataset_builder(
91108
)
92109
login_to_hf.assert_called_once_with('SECRET_TOKEN')
93110
assert builder.info.description == 'description'
94-
assert builder.info.citation == 'citation'
111+
assert builder.info.citation == 'citation from the hub'
95112
assert builder.info.redistribution_info.license == 'test-license'
96113
yield builder
97114

tensorflow_datasets/core/lazy_imports_lib.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,6 @@ def zarr(cls):
229229
def conllu(cls):
230230
return _try_import("conllu")
231231

232-
@utils.classproperty
233-
@classmethod
234-
def huggingface_hub(cls):
235-
return _try_import("huggingface_hub")
236-
237232

238233
lazy_imports = LazyImporter # pylint: disable=invalid-name
239234

tensorflow_datasets/core/utils/lazy_imports_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def datasets_error_callback(module_name: Exception):
129129

130130
with epy.lazy_imports(error_callback=datasets_error_callback):
131131
import datasets # pytype: disable=import-error
132+
import huggingface_hub # pytype: disable=import-error
132133

133134
with epy.lazy_imports(error_callback=array_record_error_callback):
134135
from array_record.python import array_record_data_source

0 commit comments

Comments (0)