Skip to content

Commit ab03dc4

Browse files
freud14 authored and soumith committed
Adding a DatasetFolder class. (#442)
* Adding tests to ImageFolder
* Adding DatasetFolder class
* Fix tests for pytest and code for lint checker
* Adding mock to requirements for ImageFolder tests
1 parent 456d3b9 commit ab03dc4

File tree

12 files changed

+164
-59
lines changed

12 files changed

+164
-59
lines changed

docs/source/datasets.rst

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ torchvision.datasets
44
All datasets are subclasses of :class:`torch.utils.data.Dataset`
55
i.e., they have ``__getitem__`` and ``__len__`` methods implemented.
66
Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
7-
which can load multiple samples parallelly using ``torch.multiprocessing`` workers.
7+
which can load multiple samples in parallel using ``torch.multiprocessing`` workers.
88
For example: ::
9-
9+
1010
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
11-
data_loader = torch.utils.data.DataLoader(imagenet_data,
11+
data_loader = torch.utils.data.DataLoader(imagenet_data,
1212
batch_size=4,
1313
shuffle=True,
1414
num_workers=args.nThreads)
@@ -22,7 +22,7 @@ All the datasets have almost similar API. They all have two common arguments:
2222
``transform`` and ``target_transform`` to transform the input and target respectively.
2323

2424

25-
.. currentmodule:: torchvision.datasets
25+
.. currentmodule:: torchvision.datasets
2626

2727

2828
MNIST
@@ -78,6 +78,14 @@ ImageFolder
7878
:members: __getitem__
7979
:special-members:
8080

81+
DatasetFolder
82+
~~~~~~~~~~~~~
83+
84+
.. autoclass:: DatasetFolder
85+
:members: __getitem__
86+
:special-members:
87+
88+
8189

8290
Imagenet-12
8391
~~~~~~~~~~~
@@ -121,4 +129,3 @@ PhotoTour
121129
.. autoclass:: PhotoTour
122130
:members: __getitem__
123131
:special-members:
124-

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def find_version(*file_paths):
3333
'pillow >= 4.1.1',
3434
'six',
3535
'torch',
36+
'mock',
3637
]
3738

3839
setup(

test/assets/dataset/a/a1.png

20.4 KB
Loading

test/assets/dataset/a/a2.png

11.4 KB
Loading

test/assets/dataset/a/a3.png

10.9 KB
Loading

test/assets/dataset/b/b1.png

12.9 KB
Loading

test/assets/dataset/b/b2.png

8.79 KB
Loading

test/assets/dataset/b/b3.png

13.4 KB
Loading

test/assets/dataset/b/b4.png

19.2 KB
Loading

test/test_folder.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import unittest
2+
try:
3+
from unittest.mock import Mock
4+
except ImportError as e:
5+
from mock import Mock
6+
7+
import os
8+
9+
from torchvision.datasets import ImageFolder
10+
11+
12+
class Tester(unittest.TestCase):
    """Exercise ``ImageFolder`` against the fixture tree ``test/assets/dataset/``.

    The expectations below assume two class directories, ``a`` (three PNG
    files) and ``b`` (four PNG files).  Every test passes an identity
    ``loader``, so the value handed back by the loader is simply the path
    it was given.

    NOTE(review): all paths are relative, so these tests presumably have to
    be launched from the repository root -- confirm against the CI command.
    """

    # Directory handed to ``ImageFolder`` as its root.
    root = 'test/assets/dataset/'
    # Class names the dataset is expected to report.
    classes = ['a', 'b']
    # Expected sample paths per class.  The directory is spelled out literally
    # here; on Python 3 a class-body comprehension cannot read ``root``.
    class_a_images = [
        os.path.join('test/assets/dataset/a/', path)
        for path in ['a1.png', 'a2.png', 'a3.png']
    ]
    class_b_images = [
        os.path.join('test/assets/dataset/b/', path)
        for path in ['b1.png', 'b2.png', 'b3.png', 'b4.png']
    ]

    def test_image_folder(self):
        """Check class discovery, the ``imgs`` listing and item access."""
        folder = ImageFolder(Tester.root, loader=lambda x: x)

        # Same class names, and ``class_to_idx`` must invert ``classes``.
        self.assertEqual(sorted(Tester.classes), sorted(folder.classes))
        for name in Tester.classes:
            self.assertEqual(name, folder.classes[folder.class_to_idx[name]])

        # Build the expected (path, class index) pairs in sorted order.
        idx_a = folder.class_to_idx['a']
        idx_b = folder.class_to_idx['b']
        expected = [(img_path, idx_a) for img_path in Tester.class_a_images]
        expected += [(img_path, idx_b) for img_path in Tester.class_b_images]
        expected.sort()
        self.assertEqual(expected, folder.imgs)

        # Indexing goes through the identity loader, so each item should be
        # the same (path, class index) pair.
        samples = [folder[i] for i in range(len(folder))]
        samples.sort()
        self.assertEqual(expected, samples)

    def test_transform(self):
        """Check that ``transform`` receives each sample and its result is returned."""
        return_value = 'test/assets/dataset/a/a1.png'
        transform = Mock(return_value=return_value)
        folder = ImageFolder(Tester.root, loader=lambda x: x, transform=transform)

        # Every item's first element must be whatever the transform returned.
        first_elements = [folder[i][0] for i in range(len(folder))]
        self.assertEqual([return_value] * len(first_elements), first_elements)

        # The transform must have been called with each image path
        # (first positional argument of every recorded call).
        expected_paths = sorted(Tester.class_a_images + Tester.class_b_images)
        received = [entry[0][0] for entry in transform.call_args_list]
        self.assertEqual(expected_paths, sorted(received))

    def test_target_transform(self):
        """Check that ``target_transform`` receives each target and its result is returned."""
        return_value = 1
        target_transform = Mock(return_value=return_value)
        folder = ImageFolder(Tester.root, loader=lambda x: x, target_transform=target_transform)

        # Every item's second element must be whatever the transform returned.
        second_elements = [folder[i][1] for i in range(len(folder))]
        self.assertEqual([return_value] * len(second_elements), second_elements)

        # The transform must have been called with the class index of each
        # sample: one call per image in ``a`` and one per image in ``b``.
        idx_a = folder.class_to_idx['a']
        idx_b = folder.class_to_idx['b']
        expected_targets = [idx_a] * len(Tester.class_a_images)
        expected_targets += [idx_b] * len(Tester.class_b_images)
        expected_targets.sort()
        received = [entry[0][0] for entry in target_transform.call_args_list]
        self.assertEqual(expected_targets, sorted(received))
57+
58+
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)