Skip to content

Commit 3b6c670

Browse files
reynoldscemsoumith
authored andcommitted
Fix for issue #447 - STL dataset returns test fold if fold is misspecified (#449)
1 parent c76ac7f commit 3b6c670

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

torchvision/datasets/stl10.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,14 @@ class STL10(CIFAR10):
4141
['test_X.bin', '7f263ba9f9e0b06b93213547f721ac82'],
4242
['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e']
4343
]
44+
splits = ('train', 'train+unlabeled', 'unlabeled', 'test')
4445

4546
def __init__(self, root, split='train',
4647
transform=None, target_transform=None, download=False):
48+
if split not in self.splits:
49+
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
50+
split, ', '.join(self.splits),
51+
))
4752
self.root = os.path.expanduser(root)
4853
self.transform = transform
4954
self.target_transform = target_transform

0 commit comments

Comments
 (0)