Skip to content

Commit fa377c4

Browse files
freud14 authored and soumith committed
Adding a DatasetFolder class. (#444)
* Adding tests to ImageFolder
* Adding DatasetFolder class
* Fix tests for pytest and code for lint checker
* Adding mock to requirements for ImageFolder tests
* Remove mocks from requirements
1 parent 7c052ce commit fa377c4

File tree

11 files changed

+170
-59
lines changed

11 files changed

+170
-59
lines changed

docs/source/datasets.rst

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ torchvision.datasets
44
All datasets are subclasses of :class:`torch.utils.data.Dataset`
55
i.e., they have ``__getitem__`` and ``__len__`` methods implemented.
66
Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
7-
which can load multiple samples parallelly using ``torch.multiprocessing`` workers.
7+
which can load multiple samples parallelly using ``torch.multiprocessing`` workers.
88
For example: ::
9-
9+
1010
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
11-
data_loader = torch.utils.data.DataLoader(imagenet_data,
11+
data_loader = torch.utils.data.DataLoader(imagenet_data,
1212
batch_size=4,
1313
shuffle=True,
1414
num_workers=args.nThreads)
@@ -22,7 +22,7 @@ All the datasets have almost similar API. They all have two common arguments:
2222
``transform`` and ``target_transform`` to transform the input and target respectively.
2323

2424

25-
.. currentmodule:: torchvision.datasets
25+
.. currentmodule:: torchvision.datasets
2626

2727

2828
MNIST
@@ -78,6 +78,14 @@ ImageFolder
7878
:members: __getitem__
7979
:special-members:
8080

81+
DatasetFolder
82+
~~~~~~~~~~~~~
83+
84+
.. autoclass:: DatasetFolder
85+
:members: __getitem__
86+
:special-members:
87+
88+
8189

8290
Imagenet-12
8391
~~~~~~~~~~~
@@ -121,4 +129,3 @@ PhotoTour
121129
.. autoclass:: PhotoTour
122130
:members: __getitem__
123131
:special-members:
124-

test/assets/dataset/a/a1.png

20.4 KB
Loading

test/assets/dataset/a/a2.png

11.4 KB
Loading

test/assets/dataset/a/a3.png

10.9 KB
Loading

test/assets/dataset/b/b1.png

12.9 KB
Loading

test/assets/dataset/b/b2.png

8.79 KB
Loading

test/assets/dataset/b/b3.png

13.4 KB
Loading

test/assets/dataset/b/b4.png

19.2 KB
Loading

test/test_folder.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import unittest
2+
3+
import os
4+
5+
from torchvision.datasets import ImageFolder
6+
7+
8+
def mock_transform(return_value, arg_list):
    """Build a stand-in transform that records every argument it receives.

    Args:
        return_value: Object the returned callable hands back on each call.
        arg_list: List that each call's argument is appended to, in call
            order. It is mutated in place so the caller can inspect it.

    Returns:
        A one-argument callable that appends its argument to ``arg_list``
        and returns ``return_value``.
    """
    def mock(arg):
        arg_list.append(arg)
        return return_value
    return mock
13+
14+
15+
class Tester(unittest.TestCase):
    """Checks ``ImageFolder`` against the fixture tree in ``test/assets/dataset/``.

    The fixture paths are relative, so they resolve against the current
    working directory.
    """

    root = 'test/assets/dataset/'
    classes = ['a', 'b']
    # The directory prefixes are spelled out rather than derived from ``root``:
    # a comprehension's element expression cannot see class-level names.
    class_a_images = [os.path.join('test/assets/dataset/a/', path)
                      for path in ['a1.png', 'a2.png', 'a3.png']]
    class_b_images = [os.path.join('test/assets/dataset/b/', path)
                      for path in ['b1.png', 'b2.png', 'b3.png', 'b4.png']]

    def test_image_folder(self):
        """Classes, class indices, ``imgs`` and indexing match the fixtures."""
        # Identity loader: each sample is whatever the dataset hands to the
        # loader, which the assertions below expect to be the file path.
        dataset = ImageFolder(Tester.root, loader=lambda x: x)
        self.assertEqual(sorted(Tester.classes), sorted(dataset.classes))
        for cls in Tester.classes:
            self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])
        class_a_idx = dataset.class_to_idx['a']
        class_b_idx = dataset.class_to_idx['b']
        imgs_a = [(img_path, class_a_idx) for img_path in Tester.class_a_images]
        imgs_b = [(img_path, class_b_idx) for img_path in Tester.class_b_images]
        imgs = sorted(imgs_a + imgs_b)
        self.assertEqual(imgs, dataset.imgs)

        outputs = sorted([dataset[i] for i in range(len(dataset))])
        self.assertEqual(imgs, outputs)

    def test_transform(self):
        """``transform`` is called once per sample and its result is returned."""
        return_value = 'test/assets/dataset/a/a1.png'

        args = []
        transform = mock_transform(return_value, args)

        dataset = ImageFolder(Tester.root, loader=lambda x: x, transform=transform)
        outputs = [dataset[i][0] for i in range(len(dataset))]
        self.assertEqual([return_value] * len(outputs), outputs)

        # The recorded arguments are the loader outputs, i.e. the image paths.
        imgs = sorted(Tester.class_a_images + Tester.class_b_images)
        self.assertEqual(imgs, sorted(args))

    def test_target_transform(self):
        """``target_transform`` is called once per target and its result is returned."""
        return_value = 1

        args = []
        target_transform = mock_transform(return_value, args)

        dataset = ImageFolder(Tester.root, loader=lambda x: x, target_transform=target_transform)
        outputs = [dataset[i][1] for i in range(len(dataset))]
        self.assertEqual([return_value] * len(outputs), outputs)

        # The recorded arguments are the untransformed class indices.
        class_a_idx = dataset.class_to_idx['a']
        class_b_idx = dataset.class_to_idx['b']
        targets = sorted([class_a_idx] * len(Tester.class_a_images) +
                         [class_b_idx] * len(Tester.class_b_images))
        self.assertEqual(targets, sorted(args))
64+
65+
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .lsun import LSUN, LSUNClass
2-
from .folder import ImageFolder
2+
from .folder import ImageFolder, DatasetFolder
33
from .coco import CocoCaptions, CocoDetection
44
from .cifar import CIFAR10, CIFAR100
55
from .stl10 import STL10
@@ -11,7 +11,7 @@
1111
from .omniglot import Omniglot
1212

1313
__all__ = ('LSUN', 'LSUNClass',
14-
'ImageFolder', 'FakeData',
14+
'ImageFolder', 'DatasetFolder', 'FakeData',
1515
'CocoCaptions', 'CocoDetection',
1616
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
1717
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',

0 commit comments

Comments
 (0)