Skip to content

Commit f4c06f6

Browse files
fineguy and The TensorFlow Datasets Authors
authored and committed
Move _build_data_dir from dataset_builder.py to file_utils.py
PiperOrigin-RevId: 683968893
1 parent 871cf8e commit f4c06f6

File tree

4 files changed

+259
-211
lines changed

4 files changed

+259
-211
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 15 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -295,13 +295,12 @@ def __init__(
295295
else: # Use the code version (do not restore data)
296296
self.info.initialize_from_bucket()
297297
if self.BLOCKED_VERSIONS is not None:
298-
config_name = self._builder_config.name if self._builder_config else None
299298
if is_blocked := self.BLOCKED_VERSIONS.is_blocked(
300-
version=self._version, config=config_name
299+
version=self._version, config=self.builder_config_name
301300
):
302301
default_msg = (
303302
f"Dataset {self.name} is blocked at version {self._version} and"
304-
f" config {config_name}."
303+
f" config {self.builder_config_name}."
305304
)
306305
self.info.set_is_blocked(
307306
is_blocked.blocked_msg if is_blocked.blocked_msg else default_msg
@@ -1129,73 +1128,24 @@ def _should_cache_ds(self, split, shuffle_files, read_config) -> bool:
11291128
# If the dataset satisfy all the right conditions, activate autocaching.
11301129
return True
11311130

1132-
def _relative_data_dir(self, with_version: bool = True) -> str:
1133-
"""Relative path of this dataset in data_dir."""
1134-
builder_data_dir = self.name
1135-
builder_config = self._builder_config
1136-
if builder_config:
1137-
builder_data_dir = os.path.join(builder_data_dir, builder_config.name)
1138-
if not with_version:
1139-
return builder_data_dir
1140-
1141-
version_data_dir = os.path.join(builder_data_dir, str(self._version))
1142-
return version_data_dir
1143-
1144-
def _build_data_dir(self, given_data_dir: Optional[str]):
1131+
def _build_data_dir(self, given_data_dir: str | None) -> tuple[str, str]:
11451132
"""Return the data directory for the current version.
11461133
11471134
Args:
1148-
given_data_dir: `Optional[str]`, root `data_dir` passed as `__init__`
1149-
argument.
1135+
given_data_dir: Root `data_dir` passed as `__init__` argument.
11501136
11511137
Returns:
1152-
data_dir_root: `str`, The root dir containing all datasets, downloads,...
1153-
data_dir: `str`, The version data_dir
1154-
(e.g. `<data_dir_root>/<ds_name>/<config>/<version>`)
1138+
data_dir: Root directory containing all datasets, downloads,...
1139+
dataset_dir: Dataset data directory (e.g.
1140+
`<data_dir>/<ds_name>/<config>/<version>`)
11551141
"""
1156-
builder_dir = self._relative_data_dir(with_version=False)
1157-
version_dir = self._relative_data_dir(with_version=True)
1158-
1159-
default_data_dir = file_utils.get_default_data_dir(
1160-
given_data_dir=given_data_dir
1142+
data_dir, dataset_dir = file_utils.get_data_dir_and_dataset_dir(
1143+
given_data_dir=given_data_dir,
1144+
builder_name=self.name,
1145+
config_name=self.builder_config_name,
1146+
version=self.version,
11611147
)
1162-
all_data_dirs = file_utils.list_data_dirs(given_data_dir=given_data_dir)
1163-
1164-
all_versions = set()
1165-
requested_version_dirs = {}
1166-
for data_dir_root in all_data_dirs:
1167-
# List all existing versions
1168-
full_builder_dir = os.path.join(data_dir_root, builder_dir)
1169-
data_dir_versions = set(utils.version.list_all_versions(full_builder_dir))
1170-
# Check for existence of the requested version
1171-
if self.version in data_dir_versions:
1172-
requested_version_dirs[data_dir_root] = os.path.join(
1173-
data_dir_root, version_dir
1174-
)
1175-
all_versions.update(data_dir_versions)
1176-
1177-
if len(requested_version_dirs) > 1:
1178-
raise ValueError(
1179-
"Dataset was found in more than one directory: {}. Please resolve "
1180-
"the ambiguity by explicitly specifying `data_dir=`."
1181-
"".format(requested_version_dirs.values())
1182-
)
1183-
elif len(requested_version_dirs) == 1: # The dataset is found once
1184-
return next(iter(requested_version_dirs.items()))
1185-
1186-
# No dataset found, use default directory
1187-
data_dir = os.path.join(default_data_dir, version_dir)
1188-
if all_versions:
1189-
logging.warning(
1190-
(
1191-
"Found a different version of the requested dataset:\n"
1192-
"%s\n"
1193-
"Using %s instead."
1194-
),
1195-
"\n".join(str(v) for v in sorted(all_versions)),
1196-
data_dir,
1197-
)
1198-
return default_data_dir, data_dir
1148+
return os.fspath(data_dir), os.fspath(dataset_dir)
11991149

12001150
def _log_download_done(self) -> None:
12011151
msg = (
@@ -2032,13 +1982,9 @@ def _save_default_config_name(
20321982
tmp_config_path.write_text(json.dumps(data))
20331983

20341984

2035-
def load_default_config_name(
2036-
common_dir: epath.Path,
2037-
) -> Optional[str]:
1985+
def load_default_config_name(builder_dir: epath.Path) -> str | None:
20381986
"""Load `builder_cls` metadata (common to all builder configs)."""
2039-
config_path = (
2040-
epath.Path(common_dir) / f".config/{constants.METADATA_FILENAME}"
2041-
)
1987+
config_path = builder_dir / ".config" / constants.METADATA_FILENAME
20421988
if not config_path.exists():
20431989
return None
20441990
data = json.loads(config_path.read_text())

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 18 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""Tests for tensorflow_datasets.core.dataset_builder."""
17-
1816
from collections.abc import Iterator, Mapping, Sequence
1917
import dataclasses
2018
import functools
@@ -353,28 +351,34 @@ def test_build_data_dir(self):
353351
with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
354352
builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
355353
self.assertEqual(str(builder.info.version), "1.0.0")
356-
builder_data_dir = os.path.join(tmp_dir, builder.name)
357-
version_dir = os.path.join(builder_data_dir, "1.0.0")
354+
355+
builder_dir = file_utils.get_dataset_dir(
356+
data_dir=tmp_dir, builder_name=builder.name
357+
)
358+
version_dir = builder_dir / "1.0.0"
358359

359360
# The dataset folder contains multiple other versions
360-
tf.io.gfile.makedirs(os.path.join(builder_data_dir, "14.0.0.invalid"))
361-
tf.io.gfile.makedirs(os.path.join(builder_data_dir, "10.0.0"))
362-
tf.io.gfile.makedirs(os.path.join(builder_data_dir, "9.0.0"))
363-
tf.io.gfile.makedirs(os.path.join(builder_data_dir, "0.1.0"))
361+
(builder_dir / "14.0.0.invalid").mkdir(parents=True, exist_ok=True)
362+
(builder_dir / "10.0.0").mkdir(parents=True, exist_ok=True)
363+
(builder_dir / "9.0.0").mkdir(parents=True, exist_ok=True)
364+
(builder_dir / "0.1.0").mkdir(parents=True, exist_ok=True)
364365

365366
# The builder's version dir is chosen
366-
self.assertEqual(builder._build_data_dir(tmp_dir)[1], version_dir)
367+
self.assertEqual(builder.data_path, version_dir)
367368

368369
def test_get_data_dir_with_config(self):
369370
with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
370371
config_name = "plus1"
371372
builder = DummyDatasetWithConfigs(config=config_name, data_dir=tmp_dir)
372373

373-
builder_data_dir = os.path.join(tmp_dir, builder.name, config_name)
374-
version_data_dir = os.path.join(builder_data_dir, "0.0.1")
374+
version_dir = file_utils.get_dataset_dir(
375+
data_dir=tmp_dir,
376+
builder_name=builder.name,
377+
config_name=config_name,
378+
version="0.0.1",
379+
)
375380

376-
tf.io.gfile.makedirs(version_data_dir)
377-
self.assertEqual(builder._build_data_dir(tmp_dir)[1], version_data_dir)
381+
self.assertEqual(builder.data_path, version_dir)
378382

379383
def test_config_construction(self):
380384
with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
@@ -775,12 +779,11 @@ class DatasetBuilderMultiDirTest(testing.TestCase):
775779
def setUpClass(cls):
776780
super(DatasetBuilderMultiDirTest, cls).setUpClass()
777781
cls.builder = DummyDatasetSharedGenerator()
778-
cls.version_dir = os.path.normpath(cls.builder.info.full_name)
779782

780783
def setUp(self):
781784
super(DatasetBuilderMultiDirTest, self).setUp()
782785
# Sanity check to make sure that no dir is registered
783-
self.assertEmpty(file_utils._registered_data_dir)
786+
file_utils.clear_registered_data_dirs()
784787
# Create a new temp dir
785788
self.other_data_dir = os.path.join(self.get_temp_dir(), "other_dir")
786789
# Overwrite the default data_dir (as files get created)
@@ -801,102 +804,6 @@ def tearDown(self):
801804
# Restore the original data dir
802805
constants.DATA_DIR = self._original_data_dir
803806

804-
def assertBuildDataDir(self, build_data_dir_out, root_dir):
805-
data_dir_root, data_dir = build_data_dir_out
806-
self.assertEqual(data_dir_root, root_dir)
807-
self.assertEqual(data_dir, os.path.join(root_dir, self.version_dir))
808-
809-
def test_default(self):
810-
# No data_dir is passed
811-
# -> use default path is used.
812-
self.assertBuildDataDir(
813-
self.builder._build_data_dir(None), self.default_data_dir
814-
)
815-
816-
def test_explicitly_passed(self):
817-
# When a dir is explicitly passed, use it.
818-
self.assertBuildDataDir(
819-
self.builder._build_data_dir(self.other_data_dir), self.other_data_dir
820-
)
821-
822-
def test_default_multi_dir(self):
823-
# No data_dir is passed
824-
# Multiple data_dirs are registered
825-
# -> use default path
826-
file_utils.add_data_dir(self.other_data_dir)
827-
self.assertBuildDataDir(
828-
self.builder._build_data_dir(None), self.default_data_dir
829-
)
830-
831-
def test_default_multi_dir_old_version_exists(self):
832-
# No data_dir is passed
833-
# Multiple data_dirs are registered
834-
# Data dir contains old versions
835-
# -> use default path
836-
file_utils.add_data_dir(self.other_data_dir)
837-
tf.io.gfile.makedirs(
838-
os.path.join(
839-
self.other_data_dir, "dummy_dataset_shared_generator", "0.1.0"
840-
)
841-
)
842-
tf.io.gfile.makedirs(
843-
os.path.join(
844-
self.other_data_dir, "dummy_dataset_shared_generator", "0.2.0"
845-
)
846-
)
847-
self.assertBuildDataDir(
848-
self.builder._build_data_dir(None), self.default_data_dir
849-
)
850-
851-
def test_default_multi_dir_version_exists(self):
852-
# No data_dir is passed
853-
# Multiple data_dirs are registered
854-
# Data found
855-
# -> Re-load existing data
856-
file_utils.add_data_dir(self.other_data_dir)
857-
tf.io.gfile.makedirs(
858-
os.path.join(
859-
self.other_data_dir, "dummy_dataset_shared_generator", "1.0.0"
860-
)
861-
)
862-
self.assertBuildDataDir(
863-
self.builder._build_data_dir(None), self.other_data_dir
864-
)
865-
866-
def test_default_multi_dir_duplicate(self):
867-
# If two data dirs contains the dataset, raise an error...
868-
file_utils.add_data_dir(self.other_data_dir)
869-
tf.io.gfile.makedirs(
870-
os.path.join(
871-
self.default_data_dir, "dummy_dataset_shared_generator", "1.0.0"
872-
)
873-
)
874-
tf.io.gfile.makedirs(
875-
os.path.join(
876-
self.other_data_dir, "dummy_dataset_shared_generator", "1.0.0"
877-
)
878-
)
879-
with self.assertRaisesRegex(ValueError, "found in more than one directory"):
880-
self.builder._build_data_dir(None)
881-
882-
def test_explicit_multi_dir(self):
883-
# If two data dirs contains the same version
884-
# Data dir is explicitly passed
885-
file_utils.add_data_dir(self.other_data_dir)
886-
tf.io.gfile.makedirs(
887-
os.path.join(
888-
self.default_data_dir, "dummy_dataset_shared_generator", "1.0.0"
889-
)
890-
)
891-
tf.io.gfile.makedirs(
892-
os.path.join(
893-
self.other_data_dir, "dummy_dataset_shared_generator", "1.0.0"
894-
)
895-
)
896-
self.assertBuildDataDir(
897-
self.builder._build_data_dir(self.other_data_dir), self.other_data_dir
898-
)
899-
900807
def test_load_data_dir(self):
901808
"""Ensure that `tfds.load` also supports multiple data_dir."""
902809
file_utils.add_data_dir(self.other_data_dir)

0 commit comments

Comments (0)