|
27 | 27 | from tensorflow_datasets.testing.dummy_config_based_datasets.dummy_ds_1 import dummy_ds_1_dataset_builder
|
28 | 28 |
|
29 | 29 |
|
| 30 | +class TestBuilderProvider(tfds.core.DatasetBuilderProvider): |
| 31 | + """Test Builder provider.""" |
| 32 | + |
| 33 | + def __init__( |
| 34 | + self, |
| 35 | + name: str, |
| 36 | + dataset_builder_cls: type[tfds.core.registered.RegisteredDataset], |
| 37 | + ): |
| 38 | + self._builder_cls_cache = {name: dataset_builder_cls} |
| 39 | + self._available_datasets = [name] |
| 40 | + |
| 41 | + def has_dataset(self, name: str) -> bool: |
| 42 | + return name in self._available_datasets |
| 43 | + |
| 44 | + def get_builder_cls( |
| 45 | + self, name: str |
| 46 | + ) -> type[tfds.core.registered.RegisteredDataset]: |
| 47 | + if not self.has_dataset(name): |
| 48 | + raise ValueError(f"Dataset {name} is not available!") |
| 49 | + return self._builder_cls_cache[name] |
| 50 | + |
| 51 | + |
30 | 52 | class EmptyDatasetBuilder(registered.RegisteredDataset):
|
31 | 53 |
|
32 | 54 | def __init__(self, **kwargs):
|
@@ -239,6 +261,32 @@ def test_is_full_name(self):
|
239 | 261 | self.assertTrue(load.is_full_name("ds/1.0.2"))
|
240 | 262 | self.assertTrue(load.is_full_name("ds_with_number123/1.0.2"))
|
241 | 263 |
|
| 264 | + def test_add_dataset_provider_to_end(self): |
| 265 | + """Adding same name dataset through dataset provider to the end of the list.""" |
| 266 | + |
| 267 | + class OriginalDataset(registered.RegisteredDataset): # pylint: disable=unused-variable |
| 268 | + pass |
| 269 | + |
| 270 | + registered.add_dataset_builder_provider( |
| 271 | + TestBuilderProvider("original_dataset", EmptyDatasetBuilder), |
| 272 | + None, # end of the list |
| 273 | + ) |
| 274 | + builder_cls = registered.imported_builder_cls("original_dataset") |
| 275 | + self.assertEqual(builder_cls, OriginalDataset) |
| 276 | + |
| 277 | + def test_add_dataset_provider_to_start(self): |
| 278 | + """Adding same name dataset through dataset provider to the start of the list.""" |
| 279 | + |
| 280 | + class MyDataset(registered.RegisteredDataset): # pylint: disable=unused-variable |
| 281 | + pass |
| 282 | + |
| 283 | + registered.add_dataset_builder_provider( |
| 284 | + TestBuilderProvider("my_dataset", EmptyDatasetBuilder), |
| 285 | + 0, # start of the list |
| 286 | + ) |
| 287 | + builder_cls = registered.imported_builder_cls("my_dataset") |
| 288 | + self.assertEqual(builder_cls, EmptyDatasetBuilder) |
| 289 | + |
242 | 290 |
|
243 | 291 | def test_skip_regitration():
|
244 | 292 | """Test `skip_registration()`."""
|
|
0 commit comments