Commit 5de600f

Merge branch 'tensorflow:master' into master
2 parents (f3bdd12 + 30a1ad0), commit 5de600f

16 files changed: +307 -116 lines

docs/_index.yaml

Lines changed: 0 additions & 5 deletions
@@ -41,11 +41,6 @@ landing_page:
       {% dynamic endif %}
   - classname: devsite-landing-row-cards
     items:
-    - heading: "Explore datasets with Know Your Data"
-      image_path: /resources/images/kyd-screenshot.jpg
-      buttons:
-      - label: Go to Know Your Data
-        path: https://knowyourdata.withgoogle.com
    - heading: Introducing TensorFlow Datasets
      image_path: /resources/images/tf-logo-card-16x9.png
      path: https://blog.tensorflow.org/2019/02/introducing-tensorflow-datasets.html

tensorflow_datasets/core/dataset_builder.py

Lines changed: 90 additions & 51 deletions
@@ -29,37 +29,42 @@
 from typing import Any, ClassVar, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union
 
 from absl import logging
-from etils import epath
-import importlib_resources
-from tensorflow_datasets.core import constants
-from tensorflow_datasets.core import dataset_info
-from tensorflow_datasets.core import dataset_metadata
-from tensorflow_datasets.core import decode
-from tensorflow_datasets.core import download
-from tensorflow_datasets.core import file_adapters
-from tensorflow_datasets.core import lazy_imports_lib
-from tensorflow_datasets.core import logging as tfds_logging
-from tensorflow_datasets.core import naming
-from tensorflow_datasets.core import reader as reader_lib
-from tensorflow_datasets.core import registered
-from tensorflow_datasets.core import split_builder as split_builder_lib
-from tensorflow_datasets.core import splits as splits_lib
-from tensorflow_datasets.core import tf_compat
-from tensorflow_datasets.core import units
-from tensorflow_datasets.core import utils
-from tensorflow_datasets.core import writer as writer_lib
-from tensorflow_datasets.core.data_sources import array_record
-from tensorflow_datasets.core.data_sources import parquet
-from tensorflow_datasets.core.proto import dataset_info_pb2
-from tensorflow_datasets.core.utils import file_utils
-from tensorflow_datasets.core.utils import gcs_utils
-from tensorflow_datasets.core.utils import read_config as read_config_lib
-from tensorflow_datasets.core.utils import type_utils
+from etils import epy
 from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
 from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
 from tensorflow_datasets.core.utils.lazy_imports_utils import tree
-import termcolor
 
+with epy.lazy_imports():
+  # pylint: disable=g-import-not-at-top
+  from etils import epath
+  import importlib_resources
+  import termcolor
+
+  from tensorflow_datasets.core import constants
+  from tensorflow_datasets.core import dataset_info
+  from tensorflow_datasets.core import dataset_metadata
+  from tensorflow_datasets.core import decode
+  from tensorflow_datasets.core import download
+  from tensorflow_datasets.core import file_adapters
+  from tensorflow_datasets.core import lazy_imports_lib
+  from tensorflow_datasets.core import logging as tfds_logging
+  from tensorflow_datasets.core import naming
+  from tensorflow_datasets.core import reader as reader_lib
+  from tensorflow_datasets.core import registered
+  from tensorflow_datasets.core import split_builder as split_builder_lib
+  from tensorflow_datasets.core import splits as splits_lib
+  from tensorflow_datasets.core import tf_compat
+  from tensorflow_datasets.core import units
+  from tensorflow_datasets.core import utils
+  from tensorflow_datasets.core import writer as writer_lib
+  from tensorflow_datasets.core.data_sources import array_record
+  from tensorflow_datasets.core.data_sources import parquet
+  from tensorflow_datasets.core.proto import dataset_info_pb2
+  from tensorflow_datasets.core.utils import file_utils
+  from tensorflow_datasets.core.utils import gcs_utils
+  from tensorflow_datasets.core.utils import read_config as read_config_lib
+  from tensorflow_datasets.core.utils import type_utils
+  # pylint: enable=g-import-not-at-top
 
 ListOrTreeOrElem = type_utils.ListOrTreeOrElem
 Tree = type_utils.Tree
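
The switch to `epy.lazy_imports()` defers the cost of importing the heavy `tensorflow_datasets.core` submodules until an attribute on them is first accessed. A minimal sketch of the pattern, assuming only that `etils` is installed (the `json` module here is just a stand-in for an expensive import):

from etils import epy

with epy.lazy_imports():
  # pylint: disable=g-import-not-at-top
  import json  # Registered lazily; the module body has not executed yet.
  # pylint: enable=g-import-not-at-top

# The real import happens here, on first attribute access.
print(json.dumps({'lazy': True}))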
@@ -726,6 +731,17 @@ def download_and_prepare(
 
     self._log_download_done()
 
+    # Execute post download and prepare hook if it exists.
+    self._post_download_and_prepare_hook()
+
+
+  def _post_download_and_prepare_hook(self) -> None:
+    """Hook to be executed after download and prepare.
+
+    Override this in custom dataset builders to execute custom logic after
+    download and prepare.
+    """
+    pass
 
   def _update_dataset_info(self) -> None:
     """Updates the `dataset_info.json` file in the dataset dir."""
@@ -767,33 +783,56 @@ def as_data_source(
     if split is None:
       split = {s: s for s in self.info.splits}
 
-    # Create a dataset for each of the given splits
-    def build_single_data_source(
-        split: str,
-    ) -> Sequence[Any]:
-      file_format = self.info.file_format
-      if file_format == file_adapters.FileFormat.ARRAY_RECORD:
-        return array_record.ArrayRecordDataSource(
-            self.info,
-            split=split,
-            decoders=decoders,
+    info = self.info
+
+    random_access_formats = file_adapters.FileFormat.with_random_access()
+    random_access_formats_msg = " or ".join(
+        [f.value for f in random_access_formats]
+    )
+    unsupported_format_msg = (
+        f"Random access data source for file format {info.file_format} is"
+        " not supported. Can you try to run download_and_prepare with"
+        f" file_format set to one of: {random_access_formats_msg}?"
+    )
+
+    if info.file_format is None and not info.alternative_file_formats:
+      raise ValueError(
+          "Dataset info file format is not set! For random access, one of the"
+          f" following formats is required: {random_access_formats_msg}"
+      )
+
+    if (
+        info.file_format is None
+        or info.file_format not in random_access_formats
+    ):
+      available_formats = set(info.alternative_file_formats)
+      suitable_formats = available_formats.intersection(random_access_formats)
+      if suitable_formats:
+        chosen_format = suitable_formats.pop()
+        logging.info(
+            "Found random access formats: %s. Chose to use %s. Overriding file"
+            " format in the dataset info.",
+            ", ".join([f.name for f in suitable_formats]),
+            chosen_format,
         )
-      elif file_format == file_adapters.FileFormat.PARQUET:
-        return parquet.ParquetDataSource(
-            self.info,
-            split=split,
-            decoders=decoders,
+        # Change the dataset info to read from a random access format.
+        info.set_file_format(
+            chosen_format, override=True, override_if_initialized=True
         )
       else:
-        args = [
-            f"`file_format='{file_format.value}'`"
-            for file_format in file_adapters.FileFormat.with_random_access()
-        ]
-        raise NotImplementedError(
-            f"Random access data source for file format {file_format} is not"
-            " supported. Can you try to run download_and_prepare with"
-            f" {' or '.join(args)}?"
-        )
+        raise NotImplementedError(unsupported_format_msg)
+
+    # Create a dataset for each of the given splits
+    def build_single_data_source(split: str) -> Sequence[Any]:
+      match info.file_format:
+        case file_adapters.FileFormat.ARRAY_RECORD:
+          return array_record.ArrayRecordDataSource(
+              info, split=split, decoders=decoders
+          )
+        case file_adapters.FileFormat.PARQUET:
+          return parquet.ParquetDataSource(info, split=split, decoders=decoders)
+        case _:
+          raise NotImplementedError(unsupported_format_msg)
 
     all_ds = tree.map_structure(build_single_data_source, split)
     return all_ds
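
After this refactoring, `as_data_source` first normalizes the file format, falling back to a random-access alternative recorded in the dataset info, then dispatches on it with a single `match`. A hedged usage sketch, assuming a TFDS setup where the illustrative 'mnist' dataset can be prepared in ARRAY_RECORD format:

import tensorflow_datasets as tfds

builder = tfds.builder('mnist', file_format='array_record')
builder.download_and_prepare()
# Random access: records are fetched by index instead of streamed.
train = builder.as_data_source(split='train')
print(len(train), train[0])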

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 26 additions & 0 deletions
@@ -578,6 +578,32 @@ def test_load_as_data_source(self):
     assert len(data_source) == 10
     assert data_source[0]["x"] == 28
 
+  def test_load_as_data_source_alternative_file_format(self):
+    data_dir = self.get_temp_dir()
+    builder = DummyDatasetWithConfigs(
+        data_dir=data_dir,
+        config="plus1",
+        file_format=file_adapters.FileFormat.ARRAY_RECORD,
+    )
+    builder.download_and_prepare()
+    # Change the default file format and add alternative file format.
+    builder.info.as_proto.file_format = "tfrecord"
+    builder.info.add_alternative_file_format("array_record")
+
+    data_source = builder.as_data_source()
+    assert isinstance(data_source, dict)
+    assert isinstance(data_source["train"], array_record.ArrayRecordDataSource)
+    assert isinstance(data_source["test"], array_record.ArrayRecordDataSource)
+    assert len(data_source["test"]) == 10
+    assert data_source["test"][0]["x"] == 28
+    assert len(data_source["train"]) == 20
+    assert data_source["train"][0]["x"] == 7
+
+    data_source = builder.as_data_source(split="test")
+    assert isinstance(data_source, array_record.ArrayRecordDataSource)
+    assert len(data_source) == 10
+    assert data_source[0]["x"] == 28
+
   @parameterized.named_parameters(
       *[
          {"file_format": file_format, "testcase_name": file_format.value}

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py

Lines changed: 43 additions & 22 deletions
@@ -42,14 +42,14 @@
 from tensorflow_datasets.core import example_serializer
 from tensorflow_datasets.core import features as feature_lib
 from tensorflow_datasets.core import file_adapters
-from tensorflow_datasets.core import lazy_imports_lib
 from tensorflow_datasets.core import split_builder as split_builder_lib
 from tensorflow_datasets.core import splits as splits_lib
 from tensorflow_datasets.core.utils import huggingface_utils
 from tensorflow_datasets.core.utils import shard_utils
 from tensorflow_datasets.core.utils import tqdm_utils
 from tensorflow_datasets.core.utils import version as version_lib
 from tensorflow_datasets.core.utils.lazy_imports_utils import datasets as hf_datasets
+from tensorflow_datasets.core.utils.lazy_imports_utils import huggingface_hub
 
 
 def _extract_supervised_keys(hf_info):
@@ -198,6 +198,7 @@ def __init__(
       hf_num_proc: Optional[int] = None,
       tfds_num_proc: Optional[int] = None,
       ignore_hf_errors: bool = False,
+      overwrite_version: str | None = None,
      **config_kwargs,
  ):
    self._hf_repo_id = hf_repo_id
@@ -216,23 +217,28 @@
         f' hf_repo_id={self._hf_repo_id}, hf_config={self._hf_config},'
         f' config_kwargs={self.config_kwargs}'
     ) from e
-    version = str(self._hf_info.version or self._hf_builder.VERSION or '1.0.0')
+    version = str(
+        overwrite_version
+        or self._hf_info.version
+        or self._hf_builder.VERSION
+        or '1.0.0'
+    )
     self.VERSION = version_lib.Version(version)  # pylint: disable=invalid-name
-    if self._hf_config:
-      self._converted_builder_config = dataset_builder.BuilderConfig(
-          name=tfds_config,
-          version=self.VERSION,
-          description=self._hf_info.description,
-      )
-    else:
-      self._converted_builder_config = None
     self.name = huggingface_utils.convert_hf_name(hf_repo_id)
     self._hf_hub_token = hf_hub_token
     self._hf_num_proc = hf_num_proc
     self._tfds_num_proc = tfds_num_proc
     self._verification_mode = (
         'no_checks' if ignore_verifications else 'all_checks'
     )
+    if self._hf_config:
+      self._converted_builder_config = dataset_builder.BuilderConfig(
+          name=tfds_config,
+          version=self.VERSION,
+          description=self._get_text_field('description'),
+      )
+    else:
+      self._converted_builder_config = None
     super().__init__(
         file_format=file_format, config=tfds_config, data_dir=data_dir
     )
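
`overwrite_version` now takes precedence over both the version advertised by the Hugging Face repo and the builder's own `VERSION`, which is useful when regenerating a dataset under a new TFDS version number. A hedged sketch (the repo id and version string are illustrative, not from this commit):

from tensorflow_datasets.core.dataset_builders import huggingface_dataset_builder

builder = huggingface_dataset_builder.HuggingfaceDatasetBuilder(
    hf_repo_id='username/my_dataset',  # Hypothetical repo.
    overwrite_version='2.0.0',
)
assert str(builder.version) == '2.0.0'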
@@ -260,8 +266,16 @@ def _hf_download_and_prepare(self):
 
   @property
   def _hf_info(self) -> hf_datasets.DatasetInfo:
+    """Retrieves the dataset info from the HuggingFace Datasets."""
     return self._hf_builder.info
 
+  @functools.cached_property
+  def _hf_hub_info(self) -> huggingface_hub.hf_api.DatasetInfo:
+    """Retrieves the dataset info from the HuggingFace Hub and caches it."""
+    return huggingface_hub.dataset_info(
+        self._hf_repo_id, token=self._hf_hub_token
+    )
+
   def _hf_features(self) -> hf_datasets.Features:
     if not self._hf_info.features:
       # We need to download and prepare the data to know its features.
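
`functools.cached_property` runs the Hub request once per builder instance and memoizes the result, so the repeated `_hf_hub_info` reads in `_get_license` and `_get_text_field` cost a single network call. A minimal, self-contained sketch of the behavior:

import functools

class HubClient:
  calls = 0

  @functools.cached_property
  def info(self) -> str:
    # The body executes only once; later accesses return the cached value.
    HubClient.calls += 1
    return 'dataset-info'

client = HubClient()
assert client.info == client.info
assert HubClient.calls == 1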
@@ -272,9 +286,9 @@ def _hf_features(self) -> hf_datasets.Features:
   def _info(self) -> dataset_info_lib.DatasetInfo:
     return dataset_info_lib.DatasetInfo(
         builder=self,
-        description=self._hf_info.description,
+        description=self._get_text_field('description'),
         features=huggingface_utils.convert_hf_features(self._hf_features()),
-        citation=self._hf_info.citation,
+        citation=self._get_text_field('citation'),
         license=self._get_license(),
         supervised_keys=_extract_supervised_keys(self._hf_info),
     )
@@ -411,24 +425,32 @@ def _write_shards(
 
   def _get_license(self) -> str | None:
     """Implements heuristics to get the license from HuggingFace."""
-    # First heuristic: check the DatasetInfo from Hugging Face datasets.
-    if self._hf_info.license:
-      return self._hf_info.license
-    huggingface_hub = lazy_imports_lib.lazy_imports.huggingface_hub
-    # Retrieve the dataset info from the HuggingFace Hub.
-    repo_id, token = self._hf_repo_id, self._hf_hub_token
-    dataset_info = huggingface_hub.dataset_info(repo_id, token=token)
-    # Second heuristic: check the card data.
+    # Heuristic #1: check the DatasetInfo from Hugging Face Hub/Datasets.
+    if info_license := self._get_text_field('license'):
+      return info_license
+    dataset_info = self._hf_hub_info
+    # Heuristic #2: check the card data.
     if dataset_info.card_data:
       if card_data_license := dataset_info.card_data.get('license'):
         return card_data_license
-    # Third heuristic: check the tags.
+    # Heuristic #3: check the tags.
     if dataset_info.tags:
       for tag in dataset_info.tags:
         if tag.startswith('license:'):
           return tag.removeprefix('license:')
     return None
 
+  def _get_text_field(self, field: str) -> str | None:
+    """Get the field from either HF Hub or HF Datasets."""
+    # The information retrieved from the Hub has priority over the one in the
+    # builder, because the Hub is allegedly the new source of truth.
+    for dataset_info in [self._hf_hub_info, self._hf_info]:
+      # `description` and `citation` are not official fields in the Hugging
+      # Face Hub API but they're still exposed in its __dict__.
+      if value := getattr(dataset_info, field, None):
+        return value
+    return None
+
 
 def builder(
     name: str, config: Optional[str] = None, **builder_kwargs
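
`_get_text_field` is a priority chain: the first metadata object exposing a truthy attribute of the requested name wins, and `getattr` with a `None` default absorbs objects that lack the field entirely. A standalone sketch of the pattern, with illustrative stand-in classes rather than the real HF metadata types:

import dataclasses

@dataclasses.dataclass
class HubInfo:  # Stand-in for the Hub metadata.
  citation: str = 'citation from the hub'

@dataclasses.dataclass
class LocalInfo:  # Stand-in for the local datasets metadata.
  citation: str = 'local citation'
  license: str = 'apache-2.0'

def get_text_field(field: str) -> str | None:
  # Earlier sources win; missing or empty fields fall through.
  for info in (HubInfo(), LocalInfo()):
    if value := getattr(info, field, None):
      return value
  return None

assert get_text_field('citation') == 'citation from the hub'
assert get_text_field('license') == 'apache-2.0'
assert get_text_field('homepage') is None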
@@ -443,5 +465,4 @@ def login_to_hf(hf_hub_token: Optional[str] = None):
   """Logs in to Hugging Face Hub with the token as arg or env variable."""
   hf_hub_token = hf_hub_token or os.environ.get('HUGGING_FACE_HUB_TOKEN')
   if hf_hub_token is not None:
-    huggingface_hub = lazy_imports_lib.lazy_imports.huggingface_hub
     huggingface_hub.login(token=hf_hub_token)

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py

Lines changed: 18 additions & 1 deletion
@@ -20,6 +20,7 @@
 import pytest
 from tensorflow_datasets.core import lazy_imports_lib
 from tensorflow_datasets.core.dataset_builders import huggingface_dataset_builder
+from tensorflow_datasets.core.utils.lazy_imports_utils import huggingface_hub
 
 PIL_Image = lazy_imports_lib.lazy_imports.PIL_Image
 
@@ -72,6 +73,22 @@ def mock_login_to_hf():
   yield login_to_hf
 
 
+@pytest.fixture(autouse=True)
+def mock_hub_dataset_info():
+  fake_dataset_info = huggingface_hub.hf_api.DatasetInfo(
+      id='foo/bar',
+      citation='citation from the hub',
+      private=False,
+      downloads=123,
+      likes=456,
+      tags=[],
+  )
+  with mock.patch.object(
+      huggingface_hub, 'dataset_info', return_value=fake_dataset_info
+  ) as dataset_info:
+    yield dataset_info
+
+
 @pytest.fixture(name='builder')
 def mock_huggingface_dataset_builder(
     tmp_path, load_dataset_builder, login_to_hf
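
Because the fixture above is declared with `autouse=True`, every test in the module sees `huggingface_hub.dataset_info` patched to return the fake `DatasetInfo`, without requesting the fixture by name. A minimal sketch of the mechanism (the fixture and env variable are illustrative):

import os
import pytest

@pytest.fixture(autouse=True)
def patched_env(monkeypatch):
  # Applied around every test in this module automatically.
  monkeypatch.setenv('TZ', 'UTC')
  yield

def test_sees_patched_env():
  assert os.environ['TZ'] == 'UTC'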
@@ -91,7 +108,7 @@ def mock_huggingface_dataset_builder(
   )
   login_to_hf.assert_called_once_with('SECRET_TOKEN')
   assert builder.info.description == 'description'
-  assert builder.info.citation == 'citation'
+  assert builder.info.citation == 'citation from the hub'
   assert builder.info.redistribution_info.license == 'test-license'
   yield builder
 
