|
27 | 27 | import pytest |
28 | 28 | import tensorflow as tf |
29 | 29 | from tensorflow_datasets import testing |
30 | | -from tensorflow_datasets.core import constants |
31 | 30 | from tensorflow_datasets.core import dataset_builder |
32 | 31 | from tensorflow_datasets.core import dataset_info |
33 | 32 | from tensorflow_datasets.core import dataset_utils |
@@ -831,32 +830,19 @@ class DatasetBuilderMultiDirTest(testing.TestCase): |
831 | 830 |
|
832 | 831 | @classmethod |
833 | 832 | def setUpClass(cls): |
834 | | - super(DatasetBuilderMultiDirTest, cls).setUpClass() |
| 833 | + super().setUpClass() |
835 | 834 | cls.builder = DummyDatasetSharedGenerator() |
836 | 835 |
|
837 | 836 | def setUp(self): |
838 | | - super(DatasetBuilderMultiDirTest, self).setUp() |
| 837 | + super().setUp() |
839 | 838 | # Sanity check to make sure that no dir is registered |
840 | 839 | file_utils.clear_registered_data_dirs() |
841 | | - # Create a new temp dir |
842 | | - self.other_data_dir = os.path.join(self.get_temp_dir(), "other_dir") |
843 | 840 | # Overwrite the default data_dir (as files get created) |
844 | | - |
845 | | - self._original_data_dir = constants.DATA_DIR |
846 | | - constants.DATA_DIR = os.path.join(self.get_temp_dir(), "default_dir") |
847 | | - self.default_data_dir = constants.DATA_DIR |
848 | | - |
849 | | - def tearDown(self): |
850 | | - super(DatasetBuilderMultiDirTest, self).tearDown() |
851 | | - # Restore to the default `_registered_data_dir` |
852 | | - file_utils._registered_data_dir = set() |
853 | | - # Clear-up existing dirs |
854 | | - if tf.io.gfile.exists(self.other_data_dir): |
855 | | - tf.io.gfile.rmtree(self.other_data_dir) |
856 | | - if tf.io.gfile.exists(self.default_data_dir): |
857 | | - tf.io.gfile.rmtree(self.default_data_dir) |
858 | | - # Restore the orgininal data dir |
859 | | - constants.DATA_DIR = self._original_data_dir |
| 841 | + default_data_dir = self.enter_context(testing.mock_default_data_dir()) |
| 842 | + # Create a new temp dir |
| 843 | + self.other_data_dir = os.path.join( |
| 844 | + os.path.dirname(default_data_dir), "other_dir" |
| 845 | + ) |
860 | 846 |
|
861 | 847 | def test_load_data_dir(self): |
862 | 848 | """Ensure that `tfds.load` also supports multiple data_dir.""" |
|
0 commit comments