5
5
import six
6
6
import string
7
7
import sys
8
+ from collections import Iterable
8
9
9
10
if sys .version_info [0 ] == 2 :
10
11
import cPickle as pickle
@@ -72,6 +73,24 @@ class LSUN(VisionDataset):
72
73
def __init__ (self , root , classes = 'train' , transform = None , target_transform = None ):
73
74
super (LSUN , self ).__init__ (root , transform = transform ,
74
75
target_transform = target_transform )
76
+ self .classes = self ._verify_classes (classes )
77
+
78
+ # for each class, create an LSUNClassDataset
79
+ self .dbs = []
80
+ for c in self .classes :
81
+ self .dbs .append (LSUNClass (
82
+ root = root + '/' + c + '_lmdb' ,
83
+ transform = transform ))
84
+
85
+ self .indices = []
86
+ count = 0
87
+ for db in self .dbs :
88
+ count += len (db )
89
+ self .indices .append (count )
90
+
91
+ self .length = count
92
+
93
+ def _verify_classes (self , classes ):
75
94
categories = ['bedroom' , 'bridge' , 'church_outdoor' , 'classroom' ,
76
95
'conference_room' , 'dining_room' , 'kitchen' ,
77
96
'living_room' , 'restaurant' , 'tower' ]
@@ -84,39 +103,28 @@ def __init__(self, root, classes='train', transform=None, target_transform=None)
84
103
else :
85
104
classes = [c + '_' + classes for c in categories ]
86
105
except ValueError :
87
- # TODO: Should this check for Iterable instead of list?
88
- if not isinstance (classes , list ):
89
- raise ValueError
106
+ if not isinstance (classes , Iterable ):
107
+ msg = ("Expected type str or Iterable for argument classes, "
108
+ "but got type {}." )
109
+ raise ValueError (msg .format (type (classes )))
110
+
111
+ classes = list (classes )
112
+ msg_fmtstr = ("Expected type str for elements in argument classes, "
113
+ "but got type {}." )
90
114
for c in classes :
91
- # TODO: This assumes each item is a str (or subclass). Should this
92
- # also be checked?
115
+ verify_str_arg (c , custom_msg = msg_fmtstr .format (type (c )))
93
116
c_short = c .split ('_' )
94
117
category , dset_opt = '_' .join (c_short [:- 1 ]), c_short [- 1 ]
95
- msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
96
118
119
+ msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
97
120
msg = msg_fmtstr .format (category , "LSUN class" ,
98
121
iterable_to_str (categories ))
99
122
verify_str_arg (category , valid_values = categories , custom_msg = msg )
100
123
101
124
msg = msg_fmtstr .format (dset_opt , "postfix" , iterable_to_str (dset_opts ))
102
125
verify_str_arg (dset_opt , valid_values = dset_opts , custom_msg = msg )
103
- finally :
104
- self .classes = classes
105
-
106
- # for each class, create an LSUNClassDataset
107
- self .dbs = []
108
- for c in self .classes :
109
- self .dbs .append (LSUNClass (
110
- root = root + '/' + c + '_lmdb' ,
111
- transform = transform ))
112
-
113
- self .indices = []
114
- count = 0
115
- for db in self .dbs :
116
- count += len (db )
117
- self .indices .append (count )
118
126
119
- self . length = count
127
+ return classes
120
128
121
129
def __getitem__ (self , index ):
122
130
"""
0 commit comments