@@ -532,7 +532,8 @@ def inject_fake_data(self, tmpdir, config):
532532        self ._create_bbox_txt (base_folder , num_images )
533533        self ._create_landmarks_txt (base_folder , num_images )
534534
535-         return  dict (num_examples = num_images_per_split [config ["split" ]], attr_names = attr_names )
535+         num_samples  =  num_images_per_split .get (config ["split" ], 0 ) if  isinstance (config ["split" ], str ) else  0 
536+         return  dict (num_examples = num_samples , attr_names = attr_names )
536537
537538    def  _create_split_txt (self , root ):
538539        num_images_per_split  =  dict (train = 4 , valid = 3 , test = 2 )
@@ -635,6 +636,28 @@ def test_transforms_v2_wrapper_spawn(self):
635636            with  self .create_dataset (target_type = target_type , transform = v2 .Resize (size = expected_size )) as  (dataset , _ ):
636637                datasets_utils .check_transforms_v2_wrapper_spawn (dataset , expected_size = expected_size )
637638
639+     def  test_invalid_split_list (self ):
640+         with  pytest .raises (ValueError , match = "Expected type str for argument split, but got type <class 'list'>." ):
641+             with  self .create_dataset (split = [1 ]):
642+                 pass 
643+ 
644+     def  test_invalid_split_int (self ):
645+         with  pytest .raises (ValueError , match = "Expected type str for argument split, but got type <class 'int'>." ):
646+             with  self .create_dataset (split = 1 ):
647+                 pass 
648+ 
649+     def  test_invalid_split_value (self ):
650+         with  pytest .raises (
651+             ValueError ,
652+             match = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}." .format (
653+                 value = "invalid" ,
654+                 arg = "split" ,
655+                 valid_values = ("train" , "valid" , "test" , "all" ),
656+             ),
657+         ):
658+             with  self .create_dataset (split = "invalid" ):
659+                 pass 
660+ 
638661
639662class  VOCSegmentationTestCase (datasets_utils .ImageDatasetTestCase ):
640663    DATASET_CLASS  =  datasets .VOCSegmentation 
0 commit comments