Skip to content

Commit 1b9304c

Browse files
committed
lsun test classes fix
1 parent d8cb7f0 commit 1b9304c

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torchvision/datasets/lsun.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ def __init__(self, db_path, classes='train',
7070
dset_opts = ['train', 'val', 'test']
7171
self.db_path = db_path
7272
if type(classes) == str and classes in dset_opts:
73-
classes = [c + '_' + classes for c in categories]
73+
if classes == 'test':
74+
classes = [classes]
75+
else:
76+
classes = [c + '_' + classes for c in categories]
7477
if type(classes) == list:
7578
for c in classes:
7679
c_short = c.split('_')

0 commit comments

Comments
 (0)