12
12
13
13
14
14
class LSUNClass (data .Dataset ):
15
- def __init__ (self , db_path , transform = None , target_transform = None ):
15
+ def __init__ (self , root , transform = None , target_transform = None ):
16
16
import lmdb
17
- self .db_path = db_path
18
- self .env = lmdb .open (db_path , max_readers = 1 , readonly = True , lock = False ,
17
+ self .root = os .path .expanduser (root )
18
+ self .transform = transform
19
+ self .target_transform = target_transform
20
+
21
+ self .env = lmdb .open (root , max_readers = 1 , readonly = True , lock = False ,
19
22
readahead = False , meminit = False )
20
23
with self .env .begin (write = False ) as txn :
21
24
self .length = txn .stat ()['entries' ]
22
- cache_file = '_cache_' + db_path .replace ('/' , '_' )
25
+ cache_file = '_cache_' + root .replace ('/' , '_' )
23
26
if os .path .isfile (cache_file ):
24
27
self .keys = pickle .load (open (cache_file , "rb" ))
25
28
else :
26
29
with self .env .begin (write = False ) as txn :
27
30
self .keys = [key for key , _ in txn .cursor ()]
28
31
pickle .dump (self .keys , open (cache_file , "wb" ))
29
- self .transform = transform
30
- self .target_transform = target_transform
31
32
32
33
def __getitem__ (self , index ):
33
34
img , target = None , None
@@ -60,7 +61,7 @@ class LSUN(data.Dataset):
60
61
`LSUN <http://lsun.cs.princeton.edu>`_ dataset.
61
62
62
63
Args:
63
- db_path (string): Root directory for the database files.
64
+ root (string): Root directory for the database files.
64
65
classes (string or list): One of {'train', 'val', 'test'} or a list of
65
66
categories to load. e,g. ['bedroom_train', 'church_train'].
66
67
transform (callable, optional): A function/transform that takes in an PIL image
@@ -69,13 +70,16 @@ class LSUN(data.Dataset):
69
70
target and transforms it.
70
71
"""
71
72
72
- def __init__ (self , db_path , classes = 'train' ,
73
+ def __init__ (self , root , classes = 'train' ,
73
74
transform = None , target_transform = None ):
74
75
categories = ['bedroom' , 'bridge' , 'church_outdoor' , 'classroom' ,
75
76
'conference_room' , 'dining_room' , 'kitchen' ,
76
77
'living_room' , 'restaurant' , 'tower' ]
77
78
dset_opts = ['train' , 'val' , 'test' ]
78
- self .db_path = db_path
79
+ self .root = os .path .expanduser (root )
80
+ self .transform = transform
81
+ self .target_transform = target_transform
82
+
79
83
if type (classes ) == str and classes in dset_opts :
80
84
if classes == 'test' :
81
85
classes = [classes ]
@@ -102,7 +106,7 @@ def __init__(self, db_path, classes='train',
102
106
self .dbs = []
103
107
for c in self .classes :
104
108
self .dbs .append (LSUNClass (
105
- db_path = db_path + '/' + c + '_lmdb' ,
109
+ root = root + '/' + c + '_lmdb' ,
106
110
transform = transform ))
107
111
108
112
self .indices = []
@@ -112,7 +116,6 @@ def __init__(self, db_path, classes='train',
112
116
self .indices .append (count )
113
117
114
118
self .length = count
115
- self .target_transform = target_transform
116
119
117
120
def __getitem__ (self , index ):
118
121
"""
@@ -146,6 +149,7 @@ def __repr__(self):
146
149
fmt_str = 'Dataset ' + self .__class__ .__name__ + '\n '
147
150
fmt_str += ' Number of datapoints: {}\n ' .format (self .__len__ ())
148
151
fmt_str += ' Root Location: {}\n ' .format (self .root )
152
+ fmt_str += ' Classes: {}\n ' .format (self .classes )
149
153
tmp = ' Transforms (if any): '
150
154
fmt_str += '{0}{1}\n ' .format (tmp , self .transform .__repr__ ().replace ('\n ' , '\n ' + ' ' * len (tmp )))
151
155
tmp = ' Target Transforms (if any): '
0 commit comments