We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c76ac7f commit 3b6c670Copy full SHA for 3b6c670
torchvision/datasets/stl10.py
@@ -41,9 +41,14 @@ class STL10(CIFAR10):
41
['test_X.bin', '7f263ba9f9e0b06b93213547f721ac82'],
42
['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e']
43
]
44
+ splits = ('train', 'train+unlabeled', 'unlabeled', 'test')
45
46
def __init__(self, root, split='train',
47
transform=None, target_transform=None, download=False):
48
+ if split not in self.splits:
49
+ raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
50
+ split, ', '.join(self.splits),
51
+ ))
52
self.root = os.path.expanduser(root)
53
self.transform = transform
54
self.target_transform = target_transform
0 commit comments