Skip to content

Commit c388d5f

Browse files
tomvdwThe TensorFlow Datasets Authors
authored andcommitted
Add a function to instantiate a DatasetReference from a dataset dir
PiperOrigin-RevId: 686902586
1 parent a15ea31 commit c388d5f

File tree

2 files changed

+134
-1
lines changed

2 files changed

+134
-1
lines changed

tensorflow_datasets/core/naming.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,6 @@ def from_tfds_name(
292292
) -> DatasetReference:
293293
"""Returns the `DatasetReference` for the given TFDS dataset."""
294294
parsed_name, builder_kwargs = parse_builder_name_kwargs(tfds_name)
295-
version, config = None, None
296295
version = builder_kwargs.get('version')
297296
config = builder_kwargs.get('config')
298297
return cls(
@@ -304,6 +303,46 @@ def from_tfds_name(
304303
data_dir=data_dir,
305304
)
306305

306+
@classmethod
307+
def from_path(
308+
cls,
309+
dataset_dir: epath.PathLike,
310+
root_data_dir: epath.PathLike,
311+
) -> DatasetReference:
312+
"""Returns the `DatasetReference` for the given dataset directory.
313+
314+
Args:
315+
dataset_dir: The path to the dataset directory, e.g.,
316+
`/data/my_dataset/my_config/1.2.3`.
317+
root_data_dir: The root data directory, e.g., `/data`.
318+
"""
319+
dataset_dir = os.fspath(dataset_dir)
320+
root_data_dir = os.fspath(root_data_dir)
321+
322+
if not dataset_dir.startswith(root_data_dir):
323+
raise ValueError(f'{dataset_dir=} does not start with {root_data_dir=}!')
324+
325+
relative_path = dataset_dir.removeprefix(root_data_dir)
326+
relative_path = relative_path.removeprefix('/').removesuffix('/')
327+
parts = relative_path.split('/')
328+
dataset_name = parts[0]
329+
if len(parts) == 2:
330+
config_name = None
331+
version = parts[1]
332+
elif len(parts) == 3:
333+
config_name = parts[1]
334+
version = parts[2]
335+
else:
336+
raise ValueError(
337+
f'Invalid {relative_path=} for {root_data_dir=} and {dataset_dir=}'
338+
)
339+
return cls(
340+
dataset_name=dataset_name,
341+
config=config_name,
342+
version=version,
343+
data_dir=root_data_dir.removesuffix('/'),
344+
)
345+
307346

308347
def references_for(
309348
name_to_tfds_name: Mapping[str, str],

tensorflow_datasets/core/naming_test.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,100 @@ def test_dataset_reference_from_tfds_name(
862862
)
863863

864864

865+
@pytest.mark.parametrize(
866+
('dataset_dir', 'root_data_dir', 'expected'),
867+
[
868+
# Dataset with a config and a version.
869+
(
870+
'/data/ds/config/1.2.3',
871+
'/data',
872+
naming.DatasetReference(
873+
dataset_name='ds',
874+
version='1.2.3',
875+
config='config',
876+
data_dir='/data',
877+
),
878+
),
879+
# Dataset with no config and a version.
880+
(
881+
'/data/ds/1.2.3',
882+
'/data',
883+
naming.DatasetReference(
884+
dataset_name='ds',
885+
version='1.2.3',
886+
config=None,
887+
data_dir='/data',
888+
),
889+
),
890+
# Dataset dir with trailing slash.
891+
(
892+
'/data/ds/config/1.2.3/',
893+
'/data',
894+
naming.DatasetReference(
895+
dataset_name='ds',
896+
version='1.2.3',
897+
config='config',
898+
data_dir='/data',
899+
),
900+
),
901+
# Root data dir with trailing slash.
902+
(
903+
'/data/ds/config/1.2.3',
904+
'/data/',
905+
naming.DatasetReference(
906+
dataset_name='ds',
907+
version='1.2.3',
908+
config='config',
909+
data_dir='/data',
910+
),
911+
),
912+
# Dataset dir and root data dir with trailing slash.
913+
(
914+
'/data/ds/config/1.2.3/',
915+
'/data/',
916+
naming.DatasetReference(
917+
dataset_name='ds',
918+
version='1.2.3',
919+
config='config',
920+
data_dir='/data',
921+
),
922+
),
923+
],
924+
)
925+
def test_dataset_reference_from_path(dataset_dir, root_data_dir, expected):
926+
actual = naming.DatasetReference.from_path(
927+
dataset_dir=dataset_dir, root_data_dir=root_data_dir
928+
)
929+
assert actual == expected
930+
931+
932+
@pytest.mark.parametrize(
933+
('dataset_dir', 'root_data_dir'),
934+
[
935+
# Root data dir is not a prefix of the dataset dir.
936+
(
937+
'/data/ds/config/1.2.3',
938+
'/somewhere_else',
939+
),
940+
# Too many nested folders.
941+
(
942+
'/data/ds/config/another_folder/1.2.3',
943+
'/data',
944+
),
945+
# Too few nested folders.
946+
(
947+
'/data/ds/',
948+
'/data',
949+
),
950+
],
951+
)
952+
def test_dataset_reference_from_path_invalid(dataset_dir, root_data_dir):
953+
with pytest.raises(ValueError):
954+
naming.DatasetReference.from_path(
955+
dataset_dir=dataset_dir, root_data_dir=root_data_dir
956+
)
957+
958+
865959
@pytest.mark.parametrize(
866960
('ds_name', 'namespace', 'version', 'config', 'tfds_name'),
867961
[

0 commit comments

Comments
 (0)