Skip to content

Commit 4886ccc

Browse files
Philip Meier authored and fmassa committed
Standardize str argument verification in datasets (#1167)
* introduced function to verify str arguments
* flake8
* added FIXME to VOC
* Fixed error message
* added test for verify_str_arg
* cleanup todos
* added option for custom error message
* fix VOC
* fixed Caltech
1 parent d9830d8 commit 4886ccc

File tree

12 files changed

+106
-110
lines changed

12 files changed

+106
-110
lines changed

test/test_datasets_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ def test_extract_gzip(self):
107107
data = nf.read()
108108
self.assertEqual(data, 'this is the content')
109109

110+
def test_verify_str_arg(self):
    """Check utils.verify_str_arg with a valid value and two invalid ones.

    Every call passes its arguments in the same positional order,
    ``(value, arg, valid_values)``, so the failure cases are checked
    against the real tuple of valid values rather than the argument name.
    """
    # A value contained in valid_values is expected to be returned as-is.
    self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
    # A non-str value is expected to raise ValueError.
    self.assertRaises(ValueError, utils.verify_str_arg, 0, "arg", ("a",))
    # A str that is not one of the valid values is expected to raise ValueError.
    self.assertRaises(ValueError, utils.verify_str_arg, "b", "arg", ("a",))
114+
110115

111116
if __name__ == '__main__':
112117
unittest.main()

torchvision/datasets/caltech.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os.path
55

66
from .vision import VisionDataset
7-
from .utils import download_and_extract_archive, makedir_exist_ok
7+
from .utils import download_and_extract_archive, makedir_exist_ok, verify_str_arg
88

99

1010
class Caltech101(VisionDataset):
@@ -32,10 +32,10 @@ def __init__(self, root, target_type="category", transform=None,
3232
transform=transform,
3333
target_transform=target_transform)
3434
makedir_exist_ok(self.root)
35-
if isinstance(target_type, list):
36-
self.target_type = target_type
37-
else:
38-
self.target_type = [target_type]
35+
if not isinstance(target_type, list):
36+
target_type = [target_type]
37+
self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation"))
38+
for t in target_type]
3939

4040
if download:
4141
self.download()
@@ -88,8 +88,6 @@ def __getitem__(self, index):
8888
self.annotation_categories[self.y[index]],
8989
"annotation_{:04d}.mat".format(self.index[index])))
9090
target.append(data["obj_contour"])
91-
else:
92-
raise ValueError("Target type \"{}\" is not recognized.".format(t))
9391
target = tuple(target) if len(target) > 1 else target[0]
9492

9593
if self.transform is not None:

torchvision/datasets/celeba.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import PIL
55
from .vision import VisionDataset
6-
from .utils import download_file_from_google_drive, check_integrity
6+
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
77

88

99
class CelebA(VisionDataset):
@@ -66,17 +66,14 @@ def __init__(self, root, split="train", target_type="attr", transform=None,
6666
raise RuntimeError('Dataset not found or corrupted.' +
6767
' You can use download=True to download it')
6868

69-
if split.lower() == "train":
70-
split = 0
71-
elif split.lower() == "valid":
72-
split = 1
73-
elif split.lower() == "test":
74-
split = 2
75-
elif split.lower() == "all":
76-
split = None
77-
else:
78-
raise ValueError('Wrong split entered! Please use "train", '
79-
'"valid", "test", or "all"')
69+
split_map = {
70+
"train": 0,
71+
"valid": 1,
72+
"test": 2,
73+
"all": None,
74+
}
75+
split = split_map[verify_str_arg(split.lower(), "split",
76+
("train", "valid", "test", "all"))]
8077

8178
fn = partial(os.path.join, self.root, self.base_folder)
8279
splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
@@ -134,6 +131,7 @@ def __getitem__(self, index):
134131
elif t == "landmarks":
135132
target.append(self.landmarks_align[index, :])
136133
else:
134+
# TODO: refactor with utils.verify_str_arg
137135
raise ValueError("Target type \"{}\" is not recognized.".format(t))
138136
target = tuple(target) if len(target) > 1 else target[0]
139137

torchvision/datasets/cityscapes.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections import namedtuple
44
import zipfile
55

6-
from .utils import extract_archive
6+
from .utils import extract_archive, verify_str_arg, iterable_to_str
77
from .vision import VisionDataset
88
from PIL import Image
99

@@ -109,22 +109,21 @@ def __init__(self, root, split='train', mode='fine', target_type='instance',
109109
self.images = []
110110
self.targets = []
111111

112-
if mode not in ['fine', 'coarse']:
113-
raise ValueError('Invalid mode! Please use mode="fine" or mode="coarse"')
114-
115-
if mode == 'fine' and split not in ['train', 'test', 'val']:
116-
raise ValueError('Invalid split for mode "fine"! Please use split="train", split="test"'
117-
' or split="val"')
118-
elif mode == 'coarse' and split not in ['train', 'train_extra', 'val']:
119-
raise ValueError('Invalid split for mode "coarse"! Please use split="train", split="train_extra"'
120-
' or split="val"')
112+
verify_str_arg(mode, "mode", ("fine", "coarse"))
113+
if mode == "fine":
114+
valid_modes = ("train", "test", "val")
115+
else:
116+
valid_modes = ("train", "train_extra", "val")
117+
msg = ("Unknown value '{}' for argument split if mode is '{}'. "
118+
"Valid values are {{{}}}.")
119+
msg = msg.format(split, mode, iterable_to_str(valid_modes))
120+
verify_str_arg(split, "split", valid_modes, msg)
121121

122122
if not isinstance(target_type, list):
123123
self.target_type = [target_type]
124-
125-
if not all(t in ['instance', 'semantic', 'polygon', 'color'] for t in self.target_type):
126-
raise ValueError('Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"'
127-
' or "color"')
124+
[verify_str_arg(value, "target_type",
125+
("instance", "semantic", "polygon", "color"))
126+
for value in self.target_type]
128127

129128
if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
130129

torchvision/datasets/imagenet.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import tempfile
55
import torch
66
from .folder import ImageFolder
7-
from .utils import check_integrity, download_and_extract_archive, extract_archive
7+
from .utils import check_integrity, download_and_extract_archive, extract_archive, \
8+
verify_str_arg
89

910
ARCHIVE_DICT = {
1011
'train': {
@@ -48,7 +49,7 @@ class ImageNet(ImageFolder):
4849

4950
def __init__(self, root, split='train', download=False, **kwargs):
5051
root = self.root = os.path.expanduser(root)
51-
self.split = self._verify_split(split)
52+
self.split = verify_str_arg(split, "split", ("train", "val"))
5253

5354
if download:
5455
self.download()
@@ -109,17 +110,6 @@ def _load_meta_file(self):
109110
def _save_meta_file(self, wnid_to_class, val_wnids):
110111
torch.save((wnid_to_class, val_wnids), self.meta_file)
111112

112-
def _verify_split(self, split):
113-
if split not in self.valid_splits:
114-
msg = "Unknown split {} .".format(split)
115-
msg += "Valid splits are {{}}.".format(", ".join(self.valid_splits))
116-
raise ValueError(msg)
117-
return split
118-
119-
@property
120-
def valid_splits(self):
121-
return 'train', 'val'
122-
123113
@property
124114
def split_folder(self):
125115
return os.path.join(self.root, self.split)

torchvision/datasets/lsun.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
else:
1212
import pickle
1313

14+
from .utils import verify_str_arg, iterable_to_str
15+
1416

1517
class LSUNClass(VisionDataset):
1618
def __init__(self, root, transform=None, target_transform=None):
@@ -75,27 +77,31 @@ def __init__(self, root, classes='train', transform=None, target_transform=None)
7577
'living_room', 'restaurant', 'tower']
7678
dset_opts = ['train', 'val', 'test']
7779

78-
if type(classes) == str and classes in dset_opts:
80+
try:
81+
verify_str_arg(classes, "classes", dset_opts)
7982
if classes == 'test':
8083
classes = [classes]
8184
else:
8285
classes = [c + '_' + classes for c in categories]
83-
elif type(classes) == list:
86+
except ValueError:
87+
# TODO: Should this check for Iterable instead of list?
88+
if not isinstance(classes, list):
89+
raise ValueError
8490
for c in classes:
91+
# TODO: This assumes each item is a str (or subclass). Should this
92+
# also be checked?
8593
c_short = c.split('_')
86-
c_short.pop(len(c_short) - 1)
87-
c_short = '_'.join(c_short)
88-
if c_short not in categories:
89-
raise (ValueError('Unknown LSUN class: ' + c_short + '.'
90-
'Options are: ' + str(categories)))
91-
c_short = c.split('_')
92-
c_short = c_short.pop(len(c_short) - 1)
93-
if c_short not in dset_opts:
94-
raise (ValueError('Unknown postfix: ' + c_short + '.'
95-
'Options are: ' + str(dset_opts)))
96-
else:
97-
raise (ValueError('Unknown option for classes'))
98-
self.classes = classes
94+
category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]
95+
msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
96+
97+
msg = msg_fmtstr.format(category, "LSUN class",
98+
iterable_to_str(categories))
99+
verify_str_arg(category, valid_values=categories, custom_msg=msg)
100+
101+
msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
102+
verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
103+
finally:
104+
self.classes = classes
99105

100106
# for each class, create an LSUNClassDataset
101107
self.dbs = []

torchvision/datasets/mnist.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import numpy as np
88
import torch
99
import codecs
10-
from .utils import download_url, download_and_extract_archive, extract_archive, makedir_exist_ok
10+
from .utils import download_url, download_and_extract_archive, extract_archive, \
11+
makedir_exist_ok, verify_str_arg
1112

1213

1314
class MNIST(VisionDataset):
@@ -230,11 +231,7 @@ class EMNIST(MNIST):
230231
splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
231232

232233
def __init__(self, root, split, **kwargs):
233-
if split not in self.splits:
234-
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
235-
split, ', '.join(self.splits),
236-
))
237-
self.split = split
234+
self.split = verify_str_arg(split, "split", self.splits)
238235
self.training_file = self._training_file(split)
239236
self.test_file = self._test_file(split)
240237
super(EMNIST, self).__init__(root, **kwargs)
@@ -336,10 +333,7 @@ class QMNIST(MNIST):
336333
def __init__(self, root, what=None, compat=True, train=True, **kwargs):
337334
if what is None:
338335
what = 'train' if train else 'test'
339-
if not self.subsets.get(what):
340-
raise RuntimeError("Argument 'what' should be one of: \n " +
341-
repr(tuple(self.subsets.keys())))
342-
self.what = what
336+
self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
343337
self.compat = compat
344338
self.data_file = what + '.pt'
345339
self.training_file = self.data_file

torchvision/datasets/sbd.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
from PIL import Image
8-
from .utils import download_url
8+
from .utils import download_url, verify_str_arg
99
from .voc import download_extract
1010

1111

@@ -64,12 +64,9 @@ def __init__(self,
6464
"pip install scipy")
6565

6666
super(SBDataset, self).__init__(root, transforms)
67-
68-
if mode not in ("segmentation", "boundaries"):
69-
raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")
70-
71-
self.image_set = image_set
72-
self.mode = mode
67+
self.image_set = verify_str_arg(image_set, "image_set",
68+
("train", "val", "train_noval"))
69+
self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
7370
self.num_classes = 20
7471

7572
sbd_root = self.root
@@ -91,11 +88,6 @@ def __init__(self,
9188

9289
split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt')
9390

94-
if not os.path.exists(split_f):
95-
raise ValueError(
96-
'Wrong image_set entered! Please use image_set="train" '
97-
'or image_set="val" or image_set="train_noval"')
98-
9991
with open(os.path.join(split_f), "r") as f:
10092
file_names = [x.strip() for x in f.readlines()]
10193

torchvision/datasets/stl10.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
from .vision import VisionDataset
8-
from .utils import check_integrity, download_and_extract_archive
8+
from .utils import check_integrity, download_and_extract_archive, verify_str_arg
99

1010

1111
class STL10(VisionDataset):
@@ -48,13 +48,9 @@ class STL10(VisionDataset):
4848

4949
def __init__(self, root, split='train', folds=None, transform=None,
5050
target_transform=None, download=False):
51-
if split not in self.splits:
52-
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
53-
split, ', '.join(self.splits),
54-
))
5551
super(STL10, self).__init__(root, transform=transform,
5652
target_transform=target_transform)
57-
self.split = split # train/test/unlabeled set
53+
self.split = verify_str_arg(split, "split", self.splits)
5854
self.folds = folds # one of the 10 pre-defined folds or the full dataset
5955

6056
if download:
@@ -167,4 +163,6 @@ def __load_folds(self, folds):
167163
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
168164
self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]
169165
else:
166+
# FIXME: docstring allows None for folds (it is even the default value)
167+
# Is this intended?
170168
raise ValueError('Folds "{}" not found. Valid splits are: 0-9.'.format(folds))

torchvision/datasets/svhn.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import os.path
66
import numpy as np
7-
from .utils import download_url, check_integrity
7+
from .utils import download_url, check_integrity, verify_str_arg
88

99

1010
class SVHN(VisionDataset):
@@ -43,12 +43,7 @@ def __init__(self, root, split='train', transform=None, target_transform=None,
4343
download=False):
4444
super(SVHN, self).__init__(root, transform=transform,
4545
target_transform=target_transform)
46-
self.split = split # training set or test set or extra set
47-
48-
if self.split not in self.split_list:
49-
raise ValueError('Wrong split entered! Please use split="train" '
50-
'or split="extra" or split="test"')
51-
46+
self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
5247
self.url = self.split_list[split][0]
5348
self.filename = self.split_list[split][1]
5449
self.file_md5 = self.split_list[split][2]

0 commit comments

Comments
 (0)