@@ -532,7 +532,8 @@ def inject_fake_data(self, tmpdir, config):
532532 self ._create_bbox_txt (base_folder , num_images )
533533 self ._create_landmarks_txt (base_folder , num_images )
534534
535- return dict (num_examples = num_images_per_split [config ["split" ]], attr_names = attr_names )
535+ num_samples = num_images_per_split .get (config ["split" ], 0 ) if isinstance (config ["split" ], str ) else 0
536+ return dict (num_examples = num_samples , attr_names = attr_names )
536537
537538 def _create_split_txt (self , root ):
538539 num_images_per_split = dict (train = 4 , valid = 3 , test = 2 )
@@ -635,6 +636,28 @@ def test_transforms_v2_wrapper_spawn(self):
635636 with self .create_dataset (target_type = target_type , transform = v2 .Resize (size = expected_size )) as (dataset , _ ):
636637 datasets_utils .check_transforms_v2_wrapper_spawn (dataset , expected_size = expected_size )
637638
639+ def test_invalid_split_list (self ):
640+ with pytest .raises (ValueError , match = "Expected type str for argument split, but got type <class 'list'>." ):
641+ with self .create_dataset (split = [1 ]):
642+ pass
643+
644+ def test_invalid_split_int (self ):
645+ with pytest .raises (ValueError , match = "Expected type str for argument split, but got type <class 'int'>." ):
646+ with self .create_dataset (split = 1 ):
647+ pass
648+
649+ def test_invalid_split_value (self ):
650+ with pytest .raises (
651+ ValueError ,
652+ match = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}." .format (
653+ value = "invalid" ,
654+ arg = "split" ,
655+ valid_values = ("train" , "valid" , "test" , "all" ),
656+ ),
657+ ):
658+ with self .create_dataset (split = "invalid" ):
659+ pass
660+
638661
639662class VOCSegmentationTestCase (datasets_utils .ImageDatasetTestCase ):
640663 DATASET_CLASS = datasets .VOCSegmentation
0 commit comments