Skip to content

Commit e51d412

Browse files
author
The TensorFlow Datasets Authors
committed
Add BlockedVersions to document which TFDS versions and configs should not be used.
PiperOrigin-RevId: 652441832
1 parent 24c1e98 commit e51d412

File tree

4 files changed

+71
-0
lines changed

4 files changed

+71
-0
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,9 @@ class DatasetBuilder(registered.RegisteredDataset):
229229
# Example: `imported_builder_cls` function in `registered.py` module sets it.
230230
pkg_dir_path: Optional[epath.Path] = None
231231

232+
# Holds information on versions and configs that should not be used.
233+
BLOCKED_VERSIONS: ClassVar[Optional[utils.BlockedVersions]] = None
234+
232235
@classmethod
233236
def _get_pkg_dir_path(cls) -> epath.Path:
234237
"""Returns class pkg_dir_path, infer it first if not set."""
@@ -396,6 +399,10 @@ def release_notes(self) -> Dict[str, str]:
396399
else:
397400
return self.RELEASE_NOTES
398401

402+
@property
403+
def blocked_versions(self) -> utils.BlockedVersions | None:
404+
return self.BLOCKED_VERSIONS
405+
399406
@property
400407
def data_dir_root(self) -> epath.Path:
401408
"""Returns the root directory where all TFDS datasets are stored.

tensorflow_datasets/core/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,5 +83,6 @@
8383
from tensorflow_datasets.core.utils.tqdm_utils import tqdm
8484
from tensorflow_datasets.core.utils.tqdm_utils import TqdmStream
8585
from tensorflow_datasets.core.utils.type_utils import *
86+
from tensorflow_datasets.core.utils.version import BlockedVersions
8687
from tensorflow_datasets.core.utils.version import Experiment
8788
from tensorflow_datasets.core.utils.version import Version

tensorflow_datasets/core/utils/version.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from __future__ import annotations
1919

20+
import dataclasses
2021
import enum
2122
import re
2223
from typing import List, Tuple, Union
@@ -31,6 +32,41 @@
3132
_VERSION_RESOLVED_REG = re.compile(_VERSION_TMPL.format(v=_NO_LEADING_ZEROS))
3233

3334

35+
# A dictionary of blocked versions or configs.
36+
# The key is a version or config string, the value is a short sentence
37+
# explaining why that version or config should not be used (or None).
38+
BlockedWithMsg = dict[str, str | None]
39+
40+
41+
@dataclasses.dataclass(frozen=True)
42+
class BlockedVersions:
43+
"""Holds information on versions and configs that should not be used.
44+
45+
Note that only complete versions can be blocked: wilcards in versions are not
46+
supported.
47+
48+
versions: A dictionary of bad versions for which all configs should be
49+
blocked.
50+
configs: A mapping from versions to a dictionary of configs that should not be
51+
used for that version.
52+
"""
53+
54+
versions: BlockedWithMsg = dataclasses.field(default_factory=dict)
55+
configs: dict[str, BlockedWithMsg] = dataclasses.field(default_factory=dict)
56+
57+
def is_blocked(
58+
self, version: str | Version, config: str | None = None
59+
) -> bool:
60+
"""Checks whether a version or config is blocked."""
61+
if isinstance(version, Version):
62+
version = str(version)
63+
if version in self.versions:
64+
return True
65+
if config is not None and version in self.configs:
66+
return config in self.configs[version]
67+
return False
68+
69+
3470
class Experiment(enum.Enum):
3571
"""Experiments which can be enabled/disabled on a per version basis.
3672

tensorflow_datasets/core/utils/version_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515

1616
"""Tests for tensorflow_datasets.core.utils.version."""
17+
18+
import pytest
1719
from tensorflow_datasets import testing
1820
from tensorflow_datasets.core.utils import version
1921

@@ -148,5 +150,30 @@ def test_str_to_version():
148150
assert v1 == v0
149151

150152

153+
@pytest.mark.parametrize(
154+
'blocked_version, blocked_config, expected',
155+
[
156+
('1.2.3', None, True),
157+
('1.0.0', None, False),
158+
('1.2.3', 'non_existing_config', True),
159+
('1.0.0', 'config_1', True),
160+
('1.0.0', 'config_2', True),
161+
('1.0.0', 'config_1', True),
162+
('1.1.0', 'config_2', False),
163+
],
164+
)
165+
def test_is_blocked(blocked_version, blocked_config, expected):
166+
blocked_versions = version.BlockedVersions(
167+
versions={'1.2.3': None},
168+
configs={
169+
'1.0.0': {'config_1': 'blocked_config', 'config_2': None},
170+
'1.1.0': {'config_1': 'blocked config in version 1.1.0'},
171+
},
172+
)
173+
assert (
174+
blocked_versions.is_blocked(blocked_version, blocked_config) == expected
175+
)
176+
177+
151178
if __name__ == '__main__':
152179
testing.test_main()

0 commit comments

Comments
 (0)