File tree Expand file tree Collapse file tree 2 files changed +38
-4
lines changed Expand file tree Collapse file tree 2 files changed +38
-4
lines changed Original file line number Diff line number Diff line change @@ -371,10 +371,16 @@ def get_builder_cls(self, name: str) -> Type[RegisteredDataset]:
371
371
return cls
372
372
373
373
374
- _DATASET_PROVIDER_REGISTRY : list [DatasetBuilderProvider ] = [
375
- SourceDirDatasetBuilderProvider (constants .DATASETS_TFDS_SRC_DIR ),
376
- LegacyDatasetBuilderProvider (),
377
- ]
374
+ def _get_inital_providers () -> list [DatasetBuilderProvider ]:
375
+ return [
376
+ SourceDirDatasetBuilderProvider (constants .DATASETS_TFDS_SRC_DIR ),
377
+ LegacyDatasetBuilderProvider (),
378
+ ]
379
+
380
+
381
+ _DATASET_PROVIDER_REGISTRY : list [DatasetBuilderProvider ] = (
382
+ _get_inital_providers ()
383
+ )
378
384
379
385
380
386
def add_dataset_builder_provider (
@@ -394,6 +400,12 @@ def add_dataset_builder_provider(
394
400
_DATASET_PROVIDER_REGISTRY .append (provider )
395
401
396
402
403
+ def reset_dataset_builder_providers () -> None :
404
+ """Resets the list of dataset builder providers to remove added providers."""
405
+ global _DATASET_PROVIDER_REGISTRY
406
+ _DATASET_PROVIDER_REGISTRY = _get_inital_providers ()
407
+
408
+
397
409
def _is_builder_available (builder_cls : Type [RegisteredDataset ]) -> bool :
398
410
"""Returns `True` is the builder is available."""
399
411
return visibility .DatasetType .TFDS_PUBLIC .is_available ()
Original file line number Diff line number Diff line change @@ -287,6 +287,28 @@ class MyDataset(registered.RegisteredDataset): # pylint: disable=unused-variabl
287
287
builder_cls = registered .imported_builder_cls ("my_dataset" )
288
288
self .assertEqual (builder_cls , EmptyDatasetBuilder )
289
289
290
+ def test_reset_dataset_builder_providers (self ):
291
+ """Resetting dataset builder providers."""
292
+
293
+ class MyRegisteredDataset (registered .RegisteredDataset ): # pylint: disable=unused-variable
294
+ pass
295
+
296
+ registered .add_dataset_builder_provider (
297
+ TestBuilderProvider ("my_registered_dataset" , EmptyDatasetBuilder ),
298
+ 0 , # start of the list
299
+ )
300
+ builder_cls = registered .imported_builder_cls ("my_registered_dataset" )
301
+ self .assertEqual (builder_cls , EmptyDatasetBuilder )
302
+
303
+ registered .reset_dataset_builder_providers ()
304
+
305
+ registered .add_dataset_builder_provider (
306
+ TestBuilderProvider ("my_registered_dataset" , EmptyDatasetBuilder ),
307
+ None , # end of the list
308
+ )
309
+ builder_cls = registered .imported_builder_cls ("my_registered_dataset" )
310
+ self .assertEqual (builder_cls , MyRegisteredDataset )
311
+
290
312
291
313
def test_skip_regitration ():
292
314
"""Test `skip_registration()`."""
You can’t perform that action at this time.
0 commit comments