Skip to content

Commit 73a29e0

Browse files
jaesunysoumith
authored andcommitted
Update LSUN Dataset class (#452)
* Fix uninitialized instance variables * Maintain consistency with other dataset classes * Fix double assignment * Fix initialization of self.classes
1 parent 0036860 commit 73a29e0

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

torchvision/datasets/lsun.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,23 @@
1212

1313

1414
class LSUNClass(data.Dataset):
15-
def __init__(self, db_path, transform=None, target_transform=None):
15+
def __init__(self, root, transform=None, target_transform=None):
1616
import lmdb
17-
self.db_path = db_path
18-
self.env = lmdb.open(db_path, max_readers=1, readonly=True, lock=False,
17+
self.root = os.path.expanduser(root)
18+
self.transform = transform
19+
self.target_transform = target_transform
20+
21+
self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
1922
readahead=False, meminit=False)
2023
with self.env.begin(write=False) as txn:
2124
self.length = txn.stat()['entries']
22-
cache_file = '_cache_' + db_path.replace('/', '_')
25+
cache_file = '_cache_' + root.replace('/', '_')
2326
if os.path.isfile(cache_file):
2427
self.keys = pickle.load(open(cache_file, "rb"))
2528
else:
2629
with self.env.begin(write=False) as txn:
2730
self.keys = [key for key, _ in txn.cursor()]
2831
pickle.dump(self.keys, open(cache_file, "wb"))
29-
self.transform = transform
30-
self.target_transform = target_transform
3132

3233
def __getitem__(self, index):
3334
img, target = None, None
@@ -60,7 +61,7 @@ class LSUN(data.Dataset):
6061
`LSUN <http://lsun.cs.princeton.edu>`_ dataset.
6162
6263
Args:
63-
db_path (string): Root directory for the database files.
64+
root (string): Root directory for the database files.
6465
classes (string or list): One of {'train', 'val', 'test'} or a list of
6566
categories to load. e,g. ['bedroom_train', 'church_train'].
6667
transform (callable, optional): A function/transform that takes in an PIL image
@@ -69,13 +70,16 @@ class LSUN(data.Dataset):
6970
target and transforms it.
7071
"""
7172

72-
def __init__(self, db_path, classes='train',
73+
def __init__(self, root, classes='train',
7374
transform=None, target_transform=None):
7475
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
7576
'conference_room', 'dining_room', 'kitchen',
7677
'living_room', 'restaurant', 'tower']
7778
dset_opts = ['train', 'val', 'test']
78-
self.db_path = db_path
79+
self.root = os.path.expanduser(root)
80+
self.transform = transform
81+
self.target_transform = target_transform
82+
7983
if type(classes) == str and classes in dset_opts:
8084
if classes == 'test':
8185
classes = [classes]
@@ -102,7 +106,7 @@ def __init__(self, db_path, classes='train',
102106
self.dbs = []
103107
for c in self.classes:
104108
self.dbs.append(LSUNClass(
105-
db_path=db_path + '/' + c + '_lmdb',
109+
root=root + '/' + c + '_lmdb',
106110
transform=transform))
107111

108112
self.indices = []
@@ -112,7 +116,6 @@ def __init__(self, db_path, classes='train',
112116
self.indices.append(count)
113117

114118
self.length = count
115-
self.target_transform = target_transform
116119

117120
def __getitem__(self, index):
118121
"""
@@ -146,6 +149,7 @@ def __repr__(self):
146149
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
147150
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
148151
fmt_str += ' Root Location: {}\n'.format(self.root)
152+
fmt_str += ' Classes: {}\n'.format(self.classes)
149153
tmp = ' Transforms (if any): '
150154
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
151155
tmp = ' Target Transforms (if any): '

0 commit comments

Comments
 (0)