13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
15
16
- """Tests for tensorflow_datasets.core.dataset_builder."""
17
-
18
16
from collections .abc import Iterator , Mapping , Sequence
19
17
import dataclasses
20
18
import functools
@@ -353,28 +351,34 @@ def test_build_data_dir(self):
353
351
with testing .tmp_dir (self .get_temp_dir ()) as tmp_dir :
354
352
builder = DummyDatasetSharedGenerator (data_dir = tmp_dir )
355
353
self .assertEqual (str (builder .info .version ), "1.0.0" )
356
- builder_data_dir = os .path .join (tmp_dir , builder .name )
357
- version_dir = os .path .join (builder_data_dir , "1.0.0" )
354
+
355
+ builder_dir = file_utils .get_dataset_dir (
356
+ data_dir = tmp_dir , builder_name = builder .name
357
+ )
358
+ version_dir = builder_dir / "1.0.0"
358
359
359
360
# The dataset folder contains multiple other versions
360
- tf . io . gfile . makedirs ( os . path . join ( builder_data_dir , "14.0.0.invalid" ))
361
- tf . io . gfile . makedirs ( os . path . join ( builder_data_dir , "10.0.0" ))
362
- tf . io . gfile . makedirs ( os . path . join ( builder_data_dir , "9.0.0" ))
363
- tf . io . gfile . makedirs ( os . path . join ( builder_data_dir , "0.1.0" ))
361
+ ( builder_dir / "14.0.0.invalid" ). mkdir ( parents = True , exist_ok = True )
362
+ ( builder_dir / "10.0.0" ). mkdir ( parents = True , exist_ok = True )
363
+ ( builder_dir / "9.0.0" ). mkdir ( parents = True , exist_ok = True )
364
+ ( builder_dir / "0.1.0" ). mkdir ( parents = True , exist_ok = True )
364
365
365
366
# The builder's version dir is chosen
366
- self .assertEqual (builder ._build_data_dir ( tmp_dir )[ 1 ] , version_dir )
367
+ self .assertEqual (builder .data_path , version_dir )
367
368
368
369
def test_get_data_dir_with_config (self ):
369
370
with testing .tmp_dir (self .get_temp_dir ()) as tmp_dir :
370
371
config_name = "plus1"
371
372
builder = DummyDatasetWithConfigs (config = config_name , data_dir = tmp_dir )
372
373
373
- builder_data_dir = os .path .join (tmp_dir , builder .name , config_name )
374
- version_data_dir = os .path .join (builder_data_dir , "0.0.1" )
374
+ version_dir = file_utils .get_dataset_dir (
375
+ data_dir = tmp_dir ,
376
+ builder_name = builder .name ,
377
+ config_name = config_name ,
378
+ version = "0.0.1" ,
379
+ )
375
380
376
- tf .io .gfile .makedirs (version_data_dir )
377
- self .assertEqual (builder ._build_data_dir (tmp_dir )[1 ], version_data_dir )
381
+ self .assertEqual (builder .data_path , version_dir )
378
382
379
383
def test_config_construction (self ):
380
384
with testing .tmp_dir (self .get_temp_dir ()) as tmp_dir :
@@ -775,12 +779,11 @@ class DatasetBuilderMultiDirTest(testing.TestCase):
775
779
def setUpClass (cls ):
776
780
super (DatasetBuilderMultiDirTest , cls ).setUpClass ()
777
781
cls .builder = DummyDatasetSharedGenerator ()
778
- cls .version_dir = os .path .normpath (cls .builder .info .full_name )
779
782
780
783
def setUp (self ):
781
784
super (DatasetBuilderMultiDirTest , self ).setUp ()
782
785
# Sanity check to make sure that no dir is registered
783
- self . assertEmpty ( file_utils ._registered_data_dir )
786
+ file_utils .clear_registered_data_dirs ( )
784
787
# Create a new temp dir
785
788
self .other_data_dir = os .path .join (self .get_temp_dir (), "other_dir" )
786
789
# Overwrite the default data_dir (as files get created)
@@ -801,102 +804,6 @@ def tearDown(self):
801
804
# Restore the orgininal data dir
802
805
constants .DATA_DIR = self ._original_data_dir
803
806
804
- def assertBuildDataDir (self , build_data_dir_out , root_dir ):
805
- data_dir_root , data_dir = build_data_dir_out
806
- self .assertEqual (data_dir_root , root_dir )
807
- self .assertEqual (data_dir , os .path .join (root_dir , self .version_dir ))
808
-
809
- def test_default (self ):
810
- # No data_dir is passed
811
- # -> use default path is used.
812
- self .assertBuildDataDir (
813
- self .builder ._build_data_dir (None ), self .default_data_dir
814
- )
815
-
816
- def test_explicitly_passed (self ):
817
- # When a dir is explicitly passed, use it.
818
- self .assertBuildDataDir (
819
- self .builder ._build_data_dir (self .other_data_dir ), self .other_data_dir
820
- )
821
-
822
- def test_default_multi_dir (self ):
823
- # No data_dir is passed
824
- # Multiple data_dirs are registered
825
- # -> use default path
826
- file_utils .add_data_dir (self .other_data_dir )
827
- self .assertBuildDataDir (
828
- self .builder ._build_data_dir (None ), self .default_data_dir
829
- )
830
-
831
- def test_default_multi_dir_old_version_exists (self ):
832
- # No data_dir is passed
833
- # Multiple data_dirs are registered
834
- # Data dir contains old versions
835
- # -> use default path
836
- file_utils .add_data_dir (self .other_data_dir )
837
- tf .io .gfile .makedirs (
838
- os .path .join (
839
- self .other_data_dir , "dummy_dataset_shared_generator" , "0.1.0"
840
- )
841
- )
842
- tf .io .gfile .makedirs (
843
- os .path .join (
844
- self .other_data_dir , "dummy_dataset_shared_generator" , "0.2.0"
845
- )
846
- )
847
- self .assertBuildDataDir (
848
- self .builder ._build_data_dir (None ), self .default_data_dir
849
- )
850
-
851
- def test_default_multi_dir_version_exists (self ):
852
- # No data_dir is passed
853
- # Multiple data_dirs are registered
854
- # Data found
855
- # -> Re-load existing data
856
- file_utils .add_data_dir (self .other_data_dir )
857
- tf .io .gfile .makedirs (
858
- os .path .join (
859
- self .other_data_dir , "dummy_dataset_shared_generator" , "1.0.0"
860
- )
861
- )
862
- self .assertBuildDataDir (
863
- self .builder ._build_data_dir (None ), self .other_data_dir
864
- )
865
-
866
- def test_default_multi_dir_duplicate (self ):
867
- # If two data dirs contains the dataset, raise an error...
868
- file_utils .add_data_dir (self .other_data_dir )
869
- tf .io .gfile .makedirs (
870
- os .path .join (
871
- self .default_data_dir , "dummy_dataset_shared_generator" , "1.0.0"
872
- )
873
- )
874
- tf .io .gfile .makedirs (
875
- os .path .join (
876
- self .other_data_dir , "dummy_dataset_shared_generator" , "1.0.0"
877
- )
878
- )
879
- with self .assertRaisesRegex (ValueError , "found in more than one directory" ):
880
- self .builder ._build_data_dir (None )
881
-
882
- def test_explicit_multi_dir (self ):
883
- # If two data dirs contains the same version
884
- # Data dir is explicitly passed
885
- file_utils .add_data_dir (self .other_data_dir )
886
- tf .io .gfile .makedirs (
887
- os .path .join (
888
- self .default_data_dir , "dummy_dataset_shared_generator" , "1.0.0"
889
- )
890
- )
891
- tf .io .gfile .makedirs (
892
- os .path .join (
893
- self .other_data_dir , "dummy_dataset_shared_generator" , "1.0.0"
894
- )
895
- )
896
- self .assertBuildDataDir (
897
- self .builder ._build_data_dir (self .other_data_dir ), self .other_data_dir
898
- )
899
-
900
807
def test_load_data_dir (self ):
901
808
"""Ensure that `tfds.load` also supports multiple data_dir."""
902
809
file_utils .add_data_dir (self .other_data_dir )
0 commit comments