Skip to content

Commit 960a2df

Browse files
author
The TensorFlow Datasets Authors
committed
Add a functionality to reset the list of dataset builder providers.BEGIN PUBLIC
PiperOrigin-RevId: 684549116
1 parent 64c483b commit 960a2df

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

tensorflow_datasets/core/registered.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,10 +371,16 @@ def get_builder_cls(self, name: str) -> Type[RegisteredDataset]:
371371
return cls
372372

373373

374-
_DATASET_PROVIDER_REGISTRY: list[DatasetBuilderProvider] = [
375-
SourceDirDatasetBuilderProvider(constants.DATASETS_TFDS_SRC_DIR),
376-
LegacyDatasetBuilderProvider(),
377-
]
374+
def _get_inital_providers() -> list[DatasetBuilderProvider]:
375+
return [
376+
SourceDirDatasetBuilderProvider(constants.DATASETS_TFDS_SRC_DIR),
377+
LegacyDatasetBuilderProvider(),
378+
]
379+
380+
381+
_DATASET_PROVIDER_REGISTRY: list[DatasetBuilderProvider] = (
382+
_get_inital_providers()
383+
)
378384

379385

380386
def add_dataset_builder_provider(
@@ -394,6 +400,12 @@ def add_dataset_builder_provider(
394400
_DATASET_PROVIDER_REGISTRY.append(provider)
395401

396402

403+
def reset_dataset_builder_providers() -> None:
404+
"""Resets the list of dataset builder providers to remove added providers."""
405+
global _DATASET_PROVIDER_REGISTRY
406+
_DATASET_PROVIDER_REGISTRY = _get_inital_providers()
407+
408+
397409
def _is_builder_available(builder_cls: Type[RegisteredDataset]) -> bool:
398410
"""Returns `True` is the builder is available."""
399411
return visibility.DatasetType.TFDS_PUBLIC.is_available()

tensorflow_datasets/core/registered_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,28 @@ class MyDataset(registered.RegisteredDataset): # pylint: disable=unused-variabl
287287
builder_cls = registered.imported_builder_cls("my_dataset")
288288
self.assertEqual(builder_cls, EmptyDatasetBuilder)
289289

290+
def test_reset_dataset_builder_providers(self):
291+
"""Resetting dataset builder providers."""
292+
293+
class MyRegisteredDataset(registered.RegisteredDataset): # pylint: disable=unused-variable
294+
pass
295+
296+
registered.add_dataset_builder_provider(
297+
TestBuilderProvider("my_registered_dataset", EmptyDatasetBuilder),
298+
0, # start of the list
299+
)
300+
builder_cls = registered.imported_builder_cls("my_registered_dataset")
301+
self.assertEqual(builder_cls, EmptyDatasetBuilder)
302+
303+
registered.reset_dataset_builder_providers()
304+
305+
registered.add_dataset_builder_provider(
306+
TestBuilderProvider("my_registered_dataset", EmptyDatasetBuilder),
307+
None, # end of the list
308+
)
309+
builder_cls = registered.imported_builder_cls("my_registered_dataset")
310+
self.assertEqual(builder_cls, MyRegisteredDataset)
311+
290312

291313
def test_skip_regitration():
292314
"""Test `skip_registration()`."""

0 commit comments

Comments
 (0)