Skip to content

Commit d2ae526

Browse files
author
The TensorFlow Datasets Authors
committed
Add unit tests with same name builders from different providers.
PiperOrigin-RevId: 684015424
1 parent f4c06f6 commit d2ae526

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

tensorflow_datasets/core/registered_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,28 @@
2727
from tensorflow_datasets.testing.dummy_config_based_datasets.dummy_ds_1 import dummy_ds_1_dataset_builder
2828

2929

30+
class TestBuilderProvider(tfds.core.DatasetBuilderProvider):
31+
"""Test Builder provider."""
32+
33+
def __init__(
34+
self,
35+
name: str,
36+
dataset_builder_cls: type[tfds.core.registered.RegisteredDataset],
37+
):
38+
self._builder_cls_cache = {name: dataset_builder_cls}
39+
self._available_datasets = [name]
40+
41+
def has_dataset(self, name: str) -> bool:
42+
return name in self._available_datasets
43+
44+
def get_builder_cls(
45+
self, name: str
46+
) -> type[tfds.core.registered.RegisteredDataset]:
47+
if not self.has_dataset(name):
48+
raise ValueError(f"Dataset {name} is not available!")
49+
return self._builder_cls_cache[name]
50+
51+
3052
class EmptyDatasetBuilder(registered.RegisteredDataset):
3153

3254
def __init__(self, **kwargs):
@@ -239,6 +261,32 @@ def test_is_full_name(self):
239261
self.assertTrue(load.is_full_name("ds/1.0.2"))
240262
self.assertTrue(load.is_full_name("ds_with_number123/1.0.2"))
241263

264+
def test_add_dataset_provider_to_end(self):
265+
"""Adding same name dataset through dataset provider to the end of the list."""
266+
267+
class OriginalDataset(registered.RegisteredDataset): # pylint: disable=unused-variable
268+
pass
269+
270+
registered.add_dataset_builder_provider(
271+
TestBuilderProvider("original_dataset", EmptyDatasetBuilder),
272+
None, # end of the list
273+
)
274+
builder_cls = registered.imported_builder_cls("original_dataset")
275+
self.assertEqual(builder_cls, OriginalDataset)
276+
277+
def test_add_dataset_provider_to_start(self):
278+
"""Adding same name dataset through dataset provider to the start of the list."""
279+
280+
class MyDataset(registered.RegisteredDataset): # pylint: disable=unused-variable
281+
pass
282+
283+
registered.add_dataset_builder_provider(
284+
TestBuilderProvider("my_dataset", EmptyDatasetBuilder),
285+
0, # start of the list
286+
)
287+
builder_cls = registered.imported_builder_cls("my_dataset")
288+
self.assertEqual(builder_cls, EmptyDatasetBuilder)
289+
242290

243291
def test_skip_regitration():
244292
"""Test `skip_registration()`."""

0 commit comments

Comments
 (0)