Skip to content

Commit 59c97d7

Browse files
Philip Meier authored and
fmassa committed
Miscellaneous dataset fixes (#1174)
* fix stl10
* fix lsun
1 parent 8102158 commit 59c97d7

File tree

2 files changed

+52
-35
lines changed

2 files changed

+52
-35
lines changed

torchvision/datasets/lsun.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import six
66
import string
77
import sys
8+
from collections import Iterable
89

910
if sys.version_info[0] == 2:
1011
import cPickle as pickle
@@ -72,6 +73,24 @@ class LSUN(VisionDataset):
7273
def __init__(self, root, classes='train', transform=None, target_transform=None):
    """Open one ``LSUNClass`` database per requested class.

    Args:
        root: Directory that contains the ``<class>_lmdb`` databases.
        classes: Class specification; it is validated and normalised by
            ``self._verify_classes`` before use.
        transform: Forwarded to the base class and to every ``LSUNClass``.
        target_transform: Forwarded to the base class only.
    """
    super(LSUN, self).__init__(root, transform=transform,
                               target_transform=target_transform)
    self.classes = self._verify_classes(classes)

    # One LSUNClass dataset per class, in the order given by self.classes.
    self.dbs = [
        LSUNClass(root=root + '/' + c + '_lmdb', transform=transform)
        for c in self.classes
    ]

    # self.indices[i] holds the combined length of self.dbs[0..i], i.e. the
    # running (cumulative) sample count after each database.
    self.indices = []
    running_total = 0
    for db in self.dbs:
        running_total += len(db)
        self.indices.append(running_total)

    # Total number of samples across all databases.
    self.length = running_total
92+
93+
def _verify_classes(self, classes):
7594
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
7695
'conference_room', 'dining_room', 'kitchen',
7796
'living_room', 'restaurant', 'tower']
@@ -84,39 +103,28 @@ def __init__(self, root, classes='train', transform=None, target_transform=None)
84103
else:
85104
classes = [c + '_' + classes for c in categories]
86105
except ValueError:
87-
# TODO: Should this check for Iterable instead of list?
88-
if not isinstance(classes, list):
89-
raise ValueError
106+
if not isinstance(classes, Iterable):
107+
msg = ("Expected type str or Iterable for argument classes, "
108+
"but got type {}.")
109+
raise ValueError(msg.format(type(classes)))
110+
111+
classes = list(classes)
112+
msg_fmtstr = ("Expected type str for elements in argument classes, "
113+
"but got type {}.")
90114
for c in classes:
91-
# TODO: This assumes each item is a str (or subclass). Should this
92-
# also be checked?
115+
verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c)))
93116
c_short = c.split('_')
94117
category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]
95-
msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
96118

119+
msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
97120
msg = msg_fmtstr.format(category, "LSUN class",
98121
iterable_to_str(categories))
99122
verify_str_arg(category, valid_values=categories, custom_msg=msg)
100123

101124
msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
102125
verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
103-
finally:
104-
self.classes = classes
105-
106-
# for each class, create an LSUNClassDataset
107-
self.dbs = []
108-
for c in self.classes:
109-
self.dbs.append(LSUNClass(
110-
root=root + '/' + c + '_lmdb',
111-
transform=transform))
112-
113-
self.indices = []
114-
count = 0
115-
for db in self.dbs:
116-
count += len(db)
117-
self.indices.append(count)
118126

119-
self.length = count
127+
return classes
120128

121129
def __getitem__(self, index):
122130
"""

torchvision/datasets/stl10.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(self, root, split='train', folds=None, transform=None,
5151
super(STL10, self).__init__(root, transform=transform,
5252
target_transform=target_transform)
5353
self.split = verify_str_arg(split, "split", self.splits)
54-
self.folds = folds # one of the 10 pre-defined folds or the full dataset
54+
self.folds = self._verify_folds(folds)
5555

5656
if download:
5757
self.download()
@@ -89,6 +89,19 @@ def __init__(self, root, split='train', folds=None, transform=None,
8989
with open(class_file) as f:
9090
self.classes = f.read().splitlines()
9191

92+
def _verify_folds(self, folds):
93+
if folds is None:
94+
return folds
95+
elif isinstance(folds, int):
96+
if folds in range(10):
97+
return folds
98+
msg = ("Value for argument folds should be in the range [0, 10), "
99+
"but got {}.")
100+
raise ValueError(msg.format(folds))
101+
else:
102+
msg = "Expected type None or int for argument folds, but got type {}."
103+
raise ValueError(msg.format(type(folds)))
104+
92105
def __getitem__(self, index):
93106
"""
94107
Args:
@@ -154,15 +167,11 @@ def extra_repr(self):
154167

155168
def __load_folds(self, folds):
156169
# loads one of the folds if specified
157-
if isinstance(folds, int):
158-
if folds >= 0 and folds < 10:
159-
path_to_folds = os.path.join(
160-
self.root, self.base_folder, self.folds_list_file)
161-
with open(path_to_folds, 'r') as f:
162-
str_idx = f.read().splitlines()[folds]
163-
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
164-
self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]
165-
else:
166-
# FIXME: docstring allows None for folds (it is even the default value)
167-
# Is this intended?
168-
raise ValueError('Folds "{}" not found. Valid splits are: 0-9.'.format(folds))
170+
if folds is None:
171+
return
172+
path_to_folds = os.path.join(
173+
self.root, self.base_folder, self.folds_list_file)
174+
with open(path_to_folds, 'r') as f:
175+
str_idx = f.read().splitlines()[folds]
176+
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
177+
self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]

0 commit comments

Comments
 (0)