Skip to content

Commit 2b9c93e

Browse files
author
The TensorFlow Datasets Authors
committed
Add assert_is_not_blocked to as_dataset and as_numpy_iterator.
PiperOrigin-RevId: 652937007
1 parent b31a385 commit 2b9c93e

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,8 @@ def as_data_source(
806806
Raises:
807807
NotImplementedError if the data was not generated using ArrayRecords.
808808
"""
809+
self.assert_is_not_blocked()
810+
809811
# By default, return all splits
810812
if split is None:
811813
split = {s: s for s in self.info.splits}
@@ -951,6 +953,8 @@ def as_dataset(
951953
If `batch_size` is -1, will return feature dictionaries containing
952954
the entire dataset in `tf.Tensor`s instead of a `tf.data.Dataset`.
953955
"""
956+
self.assert_is_not_blocked()
957+
954958
# pylint: enable=line-too-long
955959
if not self.data_path.exists():
956960
raise AssertionError(

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,25 @@ def test_assert_is_not_blocked(self):
437437
assert builder_3.assert_is_not_blocked()
438438
assert not_blocked_builder.assert_is_not_blocked() is None
439439

440+
def test_blocked_as_dataset_and_as_data_source(self):
441+
for config, version, expected_msg in [
442+
("plus1", "0.0.1", "Version 0.0.1 is blocked"),
443+
("plus2", "0.0.2", "plus2 is blocked for version 0.0.2"),
444+
]:
445+
with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
446+
tmp_dir = epath.Path(tmp_dir)
447+
blocked_builder = DummyDatasetWithBlockedVersions(
448+
config=config, version=version, data_dir=tmp_dir
449+
)
450+
with pytest.raises(
451+
utils.DatasetVariantBlockedError, match=expected_msg
452+
):
453+
blocked_builder.as_dataset()
454+
with pytest.raises(
455+
utils.DatasetVariantBlockedError, match=expected_msg
456+
):
457+
blocked_builder.as_data_source()
458+
440459
def test_versioned_configs(self):
441460
with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
442461
tmp_dir = epath.Path(tmp_dir)

0 commit comments

Comments
 (0)