Skip to content

Commit 6d37846

Browse files
author
The TensorFlow Datasets Authors
committed
Add checks for TfdsDataSource to ensure the task is not blocked.
PiperOrigin-RevId: 653203660
1 parent 5381a14 commit 6d37846

File tree

5 files changed

+39
-4
lines changed

5 files changed

+39
-4
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -557,6 +557,15 @@ def is_prepared(self) -> bool:
557557
"""Returns whether this dataset is already downloaded and prepared."""
558558
return self.data_path.exists()
559559

560+
def is_blocked(self) -> utils.IsBlocked:
  """Returns whether this builder (version, config) is blocked.

  Returns:
    The value of `self.blocked_versions.is_blocked(...)` for this builder's
    version and config name when `self.blocked_versions` is truthy, otherwise
    `utils.IsBlocked(False)`.
  """
  # A builder without a builder config is checked with `config=None`.
  config_name = self.builder_config.name if self.builder_config else None
  if blocked_versions := self.blocked_versions:
    return blocked_versions.is_blocked(
        version=self.version, config=config_name
    )
  # Nothing is registered as blocked for this builder.
  return utils.IsBlocked(False)
568+
560569
def assert_is_not_blocked(self) -> None:
561570
"""Checks that the dataset is not blocked."""
562571
config_name = self.builder_config.name if self.builder_config else None

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -401,6 +401,22 @@ def test_builder_configs_configs_with_multiple_versions(self):
401401
set(DummyDatasetWithVersionedConfigs.builder_configs.keys()),
402402
)
403403

404+
def test_is_blocked(self):
  """Checks the truthiness of `is_blocked()` for three builders.

  Two `DummyDatasetWithBlockedVersions` builders must report as blocked and
  one `DummyDatasetWithConfigs` builder must not.
  """
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    tmp_dir = epath.Path(tmp_dir)
    builder_1 = DummyDatasetWithBlockedVersions(
        config="plus1", version="0.0.1", data_dir=tmp_dir
    )
    builder_2 = DummyDatasetWithBlockedVersions(
        config="plus2", version="0.0.2", data_dir=tmp_dir
    )
    not_blocked_builder = DummyDatasetWithConfigs(
        config="plus1", version="0.0.1", data_dir=tmp_dir
    )
    # `is_blocked()` results are used directly in boolean context here.
    assert builder_1.is_blocked()
    assert builder_2.is_blocked()
    assert not not_blocked_builder.is_blocked()
419+
404420
def test_assert_is_not_blocked(self):
405421
with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
406422
tmp_dir = epath.Path(tmp_dir)
@@ -736,7 +752,7 @@ def test_default(self):
736752
)
737753

738754
def test_explicitly_passed(self):
739-
# When a dir is explictly passed, use it.
755+
# When a dir is explicitly passed, use it.
740756
self.assertBuildDataDir(
741757
self.builder._build_data_dir(self.other_data_dir), self.other_data_dir
742758
)
@@ -801,7 +817,7 @@ def test_default_multi_dir_duplicate(self):
801817
with self.assertRaisesRegex(ValueError, "found in more than one directory"):
802818
self.builder._build_data_dir(None)
803819

804-
def test_expicit_multi_dir(self):
820+
def test_explicit_multi_dir(self):
805821
# If two data dirs contains the same version
806822
# Data dir is explicitly passed
807823
file_utils.add_data_dir(self.other_data_dir)
@@ -979,7 +995,7 @@ def test_update_dataset_info_keeps_data_source(
979995
assert len(info_proto.data_source_accesses) == 1
980996
assert info_proto.data_source_accesses[0].file_system.path == "/x/y"
981997
builder.download_and_prepare()
982-
# Manually check information was indeed written in datset_info.json and
998+
# Manually check information was indeed written in dataset_info.json and
983999
# can be reloaded:
9841000
builder = testing.DummyMnist(data_dir=tmp_dir)
9851001
info_proto = builder.info.as_proto

tensorflow_datasets/core/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -86,4 +86,5 @@
8686
from tensorflow_datasets.core.utils.version import BlockedVersions
8787
from tensorflow_datasets.core.utils.version import DatasetVariantBlockedError
8888
from tensorflow_datasets.core.utils.version import Experiment
89+
from tensorflow_datasets.core.utils.version import IsBlocked
8990
from tensorflow_datasets.core.utils.version import Version

tensorflow_datasets/core/utils/version.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,9 +50,12 @@ class IsBlocked:
5050
blocked.
5151
"""
5252

53-
result: bool
53+
result: bool = False
5454
blocked_msg: str | None = None
5555

56+
def __bool__(self) -> bool:
  """Returns `self.result`, making the instance usable in boolean contexts."""
  return self.result
58+
5659

5760
@dataclasses.dataclass(frozen=True)
5861
class BlockedVersions:

tensorflow_datasets/core/utils/version_test.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -150,6 +150,12 @@ def test_str_to_version():
150150
assert v1 == v0
151151

152152

153+
@pytest.mark.parametrize('is_blocked_result', [True, False])
def test_is_blocked_bool(is_blocked_result):
  """Checks that `bool(version.IsBlocked(x))` equals `x` for both booleans."""
  is_blocked = version.IsBlocked(is_blocked_result)
  assert bool(is_blocked) == is_blocked_result
157+
158+
153159
@pytest.mark.parametrize(
154160
'blocked_version, blocked_config, expected_res, expected_msg',
155161
[

0 commit comments

Comments
 (0)