Skip to content

Commit 871cf8e

Browse files
tomvdw authored and The TensorFlow Datasets Authors committed
Create a dataset provider for legacy tfds datasets
PiperOrigin-RevId: 683576101
1 parent 32790c8 commit 871cf8e

File tree

2 files changed

+38
-19
lines changed

2 files changed

+38
-19
lines changed

tensorflow_datasets/core/registered.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,26 @@ def get_builder_cls(self, name: str) -> Type[RegisteredDataset]:
300300
...
301301

302302

303+
class LegacyDatasetBuilderProvider(DatasetBuilderProvider):
  """Serves dataset builders that are defined in the legacy codebase.

  Lookups go through the module-level `_DATASET_REGISTRY`. Legacy builders
  are imported only when a requested name is missing from that registry.
  """

  def has_dataset(self, name: str) -> bool:
    """Reports whether `name` is a registered dataset.

    Args:
      name: Name of the dataset to look up.

    Returns:
      True if `name` is in the registry, either immediately or after the
      legacy builders have been imported.
    """
    if name in _DATASET_REGISTRY:
      return True
    # Registry miss: import the legacy builders now. They are loaded lazily
    # (rather than at import time) to avoid slowing down the startup of the
    # binary.
    _import_legacy_builders()
    return name in _DATASET_REGISTRY

  def get_builder_cls(self, name: str) -> Type[RegisteredDataset]:
    """Fetches the builder class registered under `name`.

    This method does not trigger the lazy import itself; a name that is
    absent from the registry propagates whatever the registry lookup raises.

    Args:
      name: Name of the dataset whose builder class is requested.

    Returns:
      The registered builder class.

    Raises:
      PermissionError: If the builder is registered but
        `_is_builder_available` rejects it.
    """
    legacy_cls = _DATASET_REGISTRY[name]
    if _is_builder_available(legacy_cls):
      return legacy_cls
    available_types = visibility.get_availables()
    raise PermissionError(
        f'Dataset {name} is not available. Only: {available_types}'
    )
321+
322+
303323
class SourceDirDatasetBuilderProvider(DatasetBuilderProvider):
304324
"""Provider of dataset builders that are defined in the given source code folder."""
305325

@@ -352,12 +372,26 @@ def get_builder_cls(self, name: str) -> Type[RegisteredDataset]:
352372

353373

354374
_DATASET_PROVIDER_REGISTRY: list[DatasetBuilderProvider] = [
355-
SourceDirDatasetBuilderProvider(constants.DATASETS_TFDS_SRC_DIR)
375+
SourceDirDatasetBuilderProvider(constants.DATASETS_TFDS_SRC_DIR),
376+
LegacyDatasetBuilderProvider(),
356377
]
357378

358379

359-
def add_dataset_builder_provider(provider: DatasetBuilderProvider) -> None:
360-
_DATASET_PROVIDER_REGISTRY.append(provider)
380+
def add_dataset_builder_provider(
    provider: DatasetBuilderProvider,
    index: int | None = None,
) -> None:
  """Registers a dataset builder provider in the global provider registry.

  Args:
    provider: The provider to register.
    index: Position at which the provider is inserted into the registry.
      When `None` (the default), the provider goes at the end.
  """
  if index is None:
    _DATASET_PROVIDER_REGISTRY.append(provider)
    return
  _DATASET_PROVIDER_REGISTRY.insert(index, provider)
361395

362396

363397
def _is_builder_available(builder_cls: Type[RegisteredDataset]) -> bool:
@@ -429,17 +463,4 @@ def imported_builder_cls(name: str) -> Type[RegisteredDataset]:
429463
# abstract methods.
430464
raise AssertionError(f'Dataset {name} is an abstract class.')
431465

432-
if name not in _DATASET_REGISTRY:
433-
# Dataset not found in the registry, try to import legacy builders.
434-
# Dataset builders are imported lazily to avoid slowing down the startup
435-
# of the binary.
436-
_import_legacy_builders()
437-
if name not in _DATASET_REGISTRY:
438-
raise DatasetNotFoundError(f'Dataset {name} not found.')
439-
440-
builder_cls = _DATASET_REGISTRY[name]
441-
if not _is_builder_available(builder_cls):
442-
available_types = visibility.get_availables()
443-
msg = f'Dataset {name} is not available. Only: {available_types}'
444-
raise PermissionError(msg)
445-
return builder_cls # pytype: disable=bad-return-type
466+
raise DatasetNotFoundError(f'Dataset {name} not found.')

tensorflow_datasets/core/registered_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""Tests for tensorflow_datasets.core.registered."""
17-
1816
import abc
1917
import re
2018
from unittest import mock

0 commit comments

Comments
 (0)