Skip to content

Commit 8ea2e38

Browse files
fineguy and The TensorFlow Datasets Authors
authored and committed
Move file_utils.get_default_data_dir() to constants.get_default_data_dir().
PiperOrigin-RevId: 796869816
1 parent 6b4aad6 commit 8ea2e38

File tree

8 files changed

+38
-43
lines changed

8 files changed

+38
-43
lines changed

tensorflow_datasets/core/constants.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,17 @@
2828

2929
# Directory where to store processed datasets.
3030
# If modifying this, should also update `scripts/cli/build.py` `--data_dir`
31-
DATA_DIR: Final[str] = os.environ.get(
32-
'TFDS_DATA_DIR',
33-
os.path.join(os.path.expanduser('~'), 'tensorflow_datasets'),
34-
)
31+
32+
33+
def get_default_data_dir() -> str:
  """Returns the TFDS default data directory.

  The value of the `TFDS_DATA_DIR` environment variable is returned when that
  variable is set (even if it is set to an empty string). Otherwise the
  `tensorflow_datasets` directory inside the user's home directory is
  returned.
  """
  home_fallback = os.path.join(os.path.expanduser('~'), 'tensorflow_datasets')
  return os.environ.get('TFDS_DATA_DIR', home_fallback)
39+
40+
41+
DATA_DIR: Final[str] = get_default_data_dir()
3542

3643
# Prefix of files / directories which aren't finished downloading / extracting.
3744
INCOMPLETE_PREFIX = 'incomplete.'

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import pytest
2828
import tensorflow as tf
2929
from tensorflow_datasets import testing
30-
from tensorflow_datasets.core import constants
3130
from tensorflow_datasets.core import dataset_builder
3231
from tensorflow_datasets.core import dataset_info
3332
from tensorflow_datasets.core import dataset_utils
@@ -831,32 +830,19 @@ class DatasetBuilderMultiDirTest(testing.TestCase):
831830

832831
@classmethod
833832
def setUpClass(cls):
834-
super(DatasetBuilderMultiDirTest, cls).setUpClass()
833+
super().setUpClass()
835834
cls.builder = DummyDatasetSharedGenerator()
836835

837836
def setUp(self):
838-
super(DatasetBuilderMultiDirTest, self).setUp()
837+
super().setUp()
839838
# Sanity check to make sure that no dir is registered
840839
file_utils.clear_registered_data_dirs()
841-
# Create a new temp dir
842-
self.other_data_dir = os.path.join(self.get_temp_dir(), "other_dir")
843840
# Overwrite the default data_dir (as files get created)
844-
845-
self._original_data_dir = constants.DATA_DIR
846-
constants.DATA_DIR = os.path.join(self.get_temp_dir(), "default_dir")
847-
self.default_data_dir = constants.DATA_DIR
848-
849-
def tearDown(self):
850-
super(DatasetBuilderMultiDirTest, self).tearDown()
851-
# Restore to the default `_registered_data_dir`
852-
file_utils._registered_data_dir = set()
853-
# Clear-up existing dirs
854-
if tf.io.gfile.exists(self.other_data_dir):
855-
tf.io.gfile.rmtree(self.other_data_dir)
856-
if tf.io.gfile.exists(self.default_data_dir):
857-
tf.io.gfile.rmtree(self.default_data_dir)
858-
# Restore the orgininal data dir
859-
constants.DATA_DIR = self._original_data_dir
841+
default_data_dir = self.enter_context(testing.mock_default_data_dir())
842+
# Create a new temp dir
843+
self.other_data_dir = os.path.join(
844+
os.path.dirname(default_data_dir), "other_dir"
845+
)
860846

861847
def test_load_data_dir(self):
862848
"""Ensure that `tfds.load` also supports multiple data_dir."""

tensorflow_datasets/core/dataset_builders/adhoc_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585

8686
from absl import logging
8787
from etils import epath
88+
from tensorflow_datasets.core import constants
8889
from tensorflow_datasets.core import dataset_builder
8990
from tensorflow_datasets.core import dataset_info
9091
from tensorflow_datasets.core import dataset_utils

tensorflow_datasets/core/utils/file_utils.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -156,23 +156,11 @@ def list_data_dirs(
156156
else:
157157
return [given_data_dir]
158158
else:
159-
default_data_dir = get_default_data_dir(given_data_dir=given_data_dir)
159+
default_data_dir = Path(constants.DATA_DIR)
160160
all_data_dirs = _REGISTERED_DATA_DIRS | {default_data_dir}
161161
return sorted(d.expanduser() for d in all_data_dirs)
162162

163163

164-
def get_default_data_dir(given_data_dir: epath.PathLike | None = None) -> Path:
165-
"""Returns the default data_dir."""
166-
if given_data_dir:
167-
data_dir = os.path.expanduser(given_data_dir)
168-
elif 'TFDS_DATA_DIR' in os.environ:
169-
data_dir = os.environ['TFDS_DATA_DIR']
170-
else:
171-
data_dir = constants.DATA_DIR
172-
173-
return Path(data_dir)
174-
175-
176164
def get_dataset_dir(
177165
data_dir: epath.PathLike,
178166
builder_name: str,
@@ -189,11 +177,11 @@ def get_dataset_dir(
189177

190178

191179
def get_data_dir_and_dataset_dir(
192-
given_data_dir: epath.PathLike | None,
180+
given_data_dir: PathLike | None,
193181
builder_name: str,
194182
config_name: str | None,
195183
version: version_lib.Version | str | None,
196-
) -> tuple[epath.Path, epath.Path]:
184+
) -> tuple[Path, Path]:
197185
"""Returns the data and dataset directories for the given dataset.
198186
199187
Args:
@@ -249,7 +237,7 @@ def get_data_dir_and_dataset_dir(
249237
return next(iter(dataset_dir_by_data_dir.items()))
250238

251239
# No dataset found, use default directory
252-
default_data_dir = get_default_data_dir()
240+
default_data_dir = Path(constants.DATA_DIR)
253241
dataset_dir = get_dataset_dir(
254242
data_dir=default_data_dir,
255243
builder_name=builder_name,

tensorflow_datasets/core/utils/file_utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434

3535
def test_default_data_dir():
36-
data_dir = file_utils.get_default_data_dir(given_data_dir=None)
36+
data_dir = constants.get_default_data_dir()
3737
assert data_dir
3838

3939

tensorflow_datasets/scripts/cli/cli_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ class PathOptions:
197197
"""
198198

199199
data_dir: epath.Path = simple_parsing.field(
200-
default=epath.Path(constants.DATA_DIR)
200+
default=epath.Path(constants.get_default_data_dir())
201201
)
202202
download_dir: epath.Path | None = None
203203
extract_dir: epath.Path | None = None

tensorflow_datasets/testing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from tensorflow_datasets.testing.test_utils import enable_gcs_access
5151
from tensorflow_datasets.testing.test_utils import fake_examples_dir
5252
from tensorflow_datasets.testing.test_utils import make_tmp_dir
53+
from tensorflow_datasets.testing.test_utils import mock_default_data_dir
5354
from tensorflow_datasets.testing.test_utils import mock_kaggle_api
5455
from tensorflow_datasets.testing.test_utils import MockFs
5556
from tensorflow_datasets.testing.test_utils import rm_tmp_dir
@@ -84,6 +85,7 @@
8485
# TODO(afrozm): rm from here and add as methods to TestCase
8586
"make_tmp_dir": "tensorflow_datasets.testing.test_utils",
8687
"mock_data": "tensorflow_datasets.testing.mocking",
88+
"mock_default_data_dir": "tensorflow_datasets.testing.test_utils",
8789
"mock_kaggle_api": "tensorflow_datasets.testing.test_utils",
8890
"MockFs": "tensorflow_datasets.testing.test_utils",
8991
"MockPolicy": "tensorflow_datasets.testing.mocking",

tensorflow_datasets/testing/test_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from etils import epath
3434
from etils import epy
3535
import numpy as np
36+
from tensorflow_datasets.core import constants
3637
from tensorflow_datasets.core import dataset_builder
3738
from tensorflow_datasets.core import dataset_collection_builder
3839
from tensorflow_datasets.core import dataset_info
@@ -884,3 +885,13 @@ def dummy_croissant_file(
884885
croissant_file.write_text(json.dumps(dummy_metadata.to_json(), indent=2))
885886

886887
yield croissant_file
888+
889+
890+
@contextlib.contextmanager
def mock_default_data_dir() -> Iterator[str]:
  """Temporarily points `constants.DATA_DIR` at a fresh temporary directory.

  A `default_dir` sub-directory is created inside a new temporary directory
  and assigned to `constants.DATA_DIR` for the duration of the context. On
  exit (including when the body raises) the previous value of
  `constants.DATA_DIR` is restored and the temporary directory is deleted.

  Yields:
    The path of the temporary default data directory.
  """
  # Remember the current value so the override does not leak into other tests.
  original_data_dir = constants.DATA_DIR
  with tempfile.TemporaryDirectory() as tempdir:
    tmp_data_dir = os.path.join(tempdir, 'default_dir')
    os.makedirs(tmp_data_dir)
    constants.DATA_DIR = tmp_data_dir
    try:
      yield tmp_data_dir
    finally:
      # Without this, `constants.DATA_DIR` would keep pointing at a directory
      # that `TemporaryDirectory` removes right after the context exits.
      constants.DATA_DIR = original_data_dir

0 commit comments

Comments
 (0)