Skip to content

Commit a7347c0

Browse files
author
The TensorFlow Datasets Authors
committed
Check that a dataset's version (and config) are not blocked before running download_and_prepare.
PiperOrigin-RevId: 652794653
1 parent d69632d commit a7347c0

File tree

5 files changed

+115
-17
lines changed

5 files changed

+115
-17
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,16 @@ def is_prepared(self) -> bool:
557557
"""Returns whether this dataset is already downloaded and prepared."""
558558
return self.data_path.exists()
559559

560+
def assert_is_not_blocked(self) -> None:
561+
"""Checks that the dataset is not blocked."""
562+
config_name = self.builder_config.name if self.builder_config else None
563+
if blocked_versions := self.blocked_versions:
564+
is_blocked = blocked_versions.is_blocked(
565+
version=self.version, config=config_name
566+
)
567+
if is_blocked.result:
568+
raise utils.DatasetVariantBlockedError(is_blocked.blocked_msg)
569+
560570
@tfds_logging.download_and_prepare()
561571
def download_and_prepare(
562572
self,
@@ -581,7 +591,11 @@ def download_and_prepare(
581591
Raises:
582592
IOError: if there is not enough disk space available.
583593
RuntimeError: when the config cannot be found.
594+
DatasetBlockedError: if the given version, or combination of version and
595+
config, has been marked as blocked in the builder's BLOCKED_VERSIONS.
584596
"""
597+
self.assert_is_not_blocked()
598+
585599

586600
download_config = download_config or download.DownloadConfig()
587601
data_path = self.data_path

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,14 @@ def _generate_examples(self, range_):
9494
yield i, {"x": x}
9595

9696

97+
class DummyDatasetWithBlockedVersions(DummyDatasetWithConfigs):
98+
99+
BLOCKED_VERSIONS = utils.BlockedVersions(
100+
versions={"0.0.1": "Version 0.0.1 is blocked"},
101+
configs={"0.0.2": {"plus2": "plus2 is blocked for version 0.0.2"}},
102+
)
103+
104+
97105
class DummyDatasetWithDefaultConfig(DummyDatasetWithConfigs):
98106
DEFAULT_BUILDER_CONFIG_NAME = "plus2"
99107

@@ -393,6 +401,42 @@ def test_builder_configs_configs_with_multiple_versions(self):
393401
set(DummyDatasetWithVersionedConfigs.builder_configs.keys()),
394402
)
395403

404+
def test_assert_is_not_blocked(self):
405+
with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
406+
tmp_dir = epath.Path(tmp_dir)
407+
builder_1 = DummyDatasetWithBlockedVersions(
408+
config="plus1", version="0.0.1", data_dir=tmp_dir
409+
)
410+
builder_2 = DummyDatasetWithBlockedVersions(
411+
config="plus2", version="0.0.2", data_dir=tmp_dir
412+
)
413+
builder_3 = DummyDatasetWithBlockedVersions(
414+
config="plus2", version="0.0.1", data_dir=tmp_dir
415+
)
416+
not_blocked_builder = DummyDatasetWithConfigs(
417+
config="plus1", version="0.0.1", data_dir=tmp_dir
418+
)
419+
420+
assert builder_1.blocked_versions is not None
421+
assert builder_2.blocked_versions is not None
422+
assert builder_3.blocked_versions is not None
423+
assert not_blocked_builder.blocked_versions is None
424+
425+
with pytest.raises(
426+
utils.DatasetVariantBlockedError, match="Version 0.0.1 is blocked"
427+
):
428+
assert builder_1.assert_is_not_blocked()
429+
with pytest.raises(
430+
utils.DatasetVariantBlockedError,
431+
match="plus2 is blocked for version 0.0.2",
432+
):
433+
assert builder_2.assert_is_not_blocked()
434+
with pytest.raises(
435+
utils.DatasetVariantBlockedError, match="Version 0.0.1 is blocked"
436+
):
437+
assert builder_3.assert_is_not_blocked()
438+
assert not_blocked_builder.assert_is_not_blocked() is None
439+
396440
def test_versioned_configs(self):
397441
with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
398442
tmp_dir = epath.Path(tmp_dir)

tensorflow_datasets/core/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,6 @@
8484
from tensorflow_datasets.core.utils.tqdm_utils import TqdmStream
8585
from tensorflow_datasets.core.utils.type_utils import *
8686
from tensorflow_datasets.core.utils.version import BlockedVersions
87+
from tensorflow_datasets.core.utils.version import DatasetVariantBlockedError
8788
from tensorflow_datasets.core.utils.version import Experiment
8889
from tensorflow_datasets.core.utils.version import Version

tensorflow_datasets/core/utils/version.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,28 @@
3232
_VERSION_RESOLVED_REG = re.compile(_VERSION_TMPL.format(v=_NO_LEADING_ZEROS))
3333

3434

35+
class DatasetVariantBlockedError(ValueError):
36+
"""Exception raised when a blocked version and/or config is requested."""
37+
38+
3539
# A dictionary of blocked versions or configs.
3640
# The key is a version or config string, the value is a short sentence
3741
# explaining why that version or config should not be used (or None).
3842
BlockedWithMsg = dict[str, str | None]
3943

4044

45+
@dataclasses.dataclass(frozen=True)
46+
class IsBlocked:
47+
"""Class to store information about a version or config being blocked.
48+
49+
Also contains an optional message explaining why the version or config is
50+
blocked.
51+
"""
52+
53+
result: bool
54+
blocked_msg: str | None = None
55+
56+
4157
@dataclasses.dataclass(frozen=True)
4258
class BlockedVersions:
4359
"""Holds information on versions and configs that should not be used.
@@ -56,15 +72,31 @@ class BlockedVersions:
5672

5773
def is_blocked(
5874
self, version: str | Version, config: str | None = None
59-
) -> bool:
60-
"""Checks whether a version or config is blocked."""
75+
) -> IsBlocked:
76+
"""Checks whether a version or config is blocked.
77+
78+
Args:
79+
version: The version to check.
80+
config: The config to check. If None, the version is checked.
81+
82+
Returns:
83+
An IsBlocked object. If IsBlocked.result is True, IsBlocked.blocked_msg
84+
contains the message explaining why the version or config is blocked, if
85+
it exists, or a default message otherwise.
86+
"""
6187
if isinstance(version, Version):
6288
version = str(version)
6389
if version in self.versions:
64-
return True
90+
blocked_msg = self.versions[version] or f"Version {version} is blocked."
91+
return IsBlocked(True, blocked_msg)
6592
if config is not None and version in self.configs:
66-
return config in self.configs[version]
67-
return False
93+
if config in self.configs[version]:
94+
blocked_msg = (
95+
self.configs[version][config]
96+
or f"Config {config} for version {version} is blocked."
97+
)
98+
return IsBlocked(True, blocked_msg)
99+
return IsBlocked(False)
68100

69101

70102
class Experiment(enum.Enum):

tensorflow_datasets/core/utils/version_test.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -151,28 +151,35 @@ def test_str_to_version():
151151

152152

153153
@pytest.mark.parametrize(
154-
'blocked_version, blocked_config, expected',
154+
'blocked_version, blocked_config, expected_res, expected_msg',
155155
[
156-
('1.2.3', None, True),
157-
('1.0.0', None, False),
158-
('1.2.3', 'non_existing_config', True),
159-
('1.0.0', 'config_1', True),
160-
('1.0.0', 'config_2', True),
161-
('1.0.0', 'config_1', True),
162-
('1.1.0', 'config_2', False),
156+
('1.2.3', None, True, 'Version 1.2.3 is blocked.'),
157+
('1.0.0', None, False, None),
158+
('1.2.3', 'non_existing_config', True, 'Version 1.2.3 is blocked.'),
159+
('1.0.0', 'config_1', True, 'blocked_config'),
160+
(
161+
'1.0.0',
162+
'config_2',
163+
True,
164+
'Config config_2 for version 1.0.0 is blocked.',
165+
),
166+
('1.1.0', 'config_1', True, 'blocked config in version 1.1.0'),
167+
('1.1.0', 'config_2', False, None),
163168
],
164169
)
165-
def test_is_blocked(blocked_version, blocked_config, expected):
170+
def test_is_blocked(
171+
blocked_version, blocked_config, expected_res, expected_msg
172+
):
166173
blocked_versions = version.BlockedVersions(
167174
versions={'1.2.3': None},
168175
configs={
169176
'1.0.0': {'config_1': 'blocked_config', 'config_2': None},
170177
'1.1.0': {'config_1': 'blocked config in version 1.1.0'},
171178
},
172179
)
173-
assert (
174-
blocked_versions.is_blocked(blocked_version, blocked_config) == expected
175-
)
180+
is_blocked = blocked_versions.is_blocked(blocked_version, blocked_config)
181+
assert is_blocked.result == expected_res
182+
assert is_blocked.blocked_msg == expected_msg
176183

177184

178185
if __name__ == '__main__':

0 commit comments

Comments
 (0)