diff --git a/test/test_datasets.py b/test/test_datasets.py index 7e91571744a..1c1d05ac42a 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -532,7 +532,8 @@ def inject_fake_data(self, tmpdir, config): self._create_bbox_txt(base_folder, num_images) self._create_landmarks_txt(base_folder, num_images) - return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names) + num_samples = num_images_per_split.get(config["split"], 0) if isinstance(config["split"], str) else 0 + return dict(num_examples=num_samples, attr_names=attr_names) def _create_split_txt(self, root): num_images_per_split = dict(train=4, valid=3, test=2) @@ -635,6 +636,28 @@ def test_transforms_v2_wrapper_spawn(self): with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _): datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) + def test_invalid_split_list(self): + with pytest.raises(ValueError, match="Expected type str for argument split, but got type ."): + with self.create_dataset(split=[1]): + pass + + def test_invalid_split_int(self): + with pytest.raises(ValueError, match="Expected type str for argument split, but got type ."): + with self.create_dataset(split=1): + pass + + def test_invalid_split_value(self): + with pytest.raises( + ValueError, + match="Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}.".format( + value="invalid", + arg="split", + valid_values=("train", "valid", "test", "all"), + ), + ): + with self.create_dataset(split="invalid"): + pass + class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.VOCSegmentation diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py index c15120af5a5..01fb619e778 100644 --- a/torchvision/datasets/celeba.py +++ b/torchvision/datasets/celeba.py @@ -93,7 +93,13 @@ def __init__( "test": 2, "all": None, } - split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))] + split_ = split_map[ + verify_str_arg( + split.lower() if isinstance(split, str) else split, + "split", + ("train", "valid", "test", "all"), + ) + ] splits = self._load_csv("list_eval_partition.txt") identity = self._load_csv("identity_CelebA.txt") bbox = self._load_csv("list_bbox_celeba.txt", header=1)