Skip to content

Commit 432aa00

Browse files
chsasankfmassa
authored andcommitted
Improve torchvision documentation (#179)
* Add documentation for transforms * document and remove unused imports in mnist.py * document lsun, mscoco datasets * rest of the datasets documented * Clean up the documentation in other functions * Add links for datasets * Add more documentation * pep8 fix
1 parent fa2836c commit 432aa00

File tree

12 files changed

+379
-58
lines changed

12 files changed

+379
-58
lines changed

torchvision/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@ def set_image_backend(backend):
1111
"""
1212
Specifies the package used to load images.
1313
14-
Options are 'PIL' and 'accimage'. The :mod:`accimage` package uses the
15-
Intel IPP library. It is generally faster than PIL, but does not support as
16-
many operations.
17-
1814
Args:
19-
backend (string): name of the image backend
15+
backend (string): Name of the image backend. one of {'PIL', 'accimage'}.
16+
The :mod:`accimage` package uses the Intel IPP library. It is
17+
generally faster than PIL, but does not support as many operations.
2018
"""
2119
global _image_backend
2220
if backend not in ['PIL', 'accimage']:

torchvision/datasets/cifar.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,22 @@
1515

1616

1717
class CIFAR10(data.Dataset):
18+
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
19+
20+
Args:
21+
root (string): Root directory of dataset where directory
22+
``cifar-10-batches-py`` exists.
23+
train (bool, optional): If True, creates dataset from training set, otherwise
24+
creates from test set.
25+
transform (callable, optional): A function/transform that takes in an PIL image
26+
and returns a transformed version. E.g, ``transforms.RandomCrop``
27+
target_transform (callable, optional): A function/transform that takes in the
28+
target and transforms it.
29+
download (bool, optional): If true, downloads the dataset from the internet and
30+
puts it in root directory. If dataset is already downloaded, it is not
31+
downloaded again.
32+
33+
"""
1834
base_folder = 'cifar-10-batches-py'
1935
url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
2036
filename = "cifar-10-python.tar.gz"
@@ -86,6 +102,13 @@ def __init__(self, root, train=True,
86102
self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC
87103

88104
def __getitem__(self, index):
105+
"""
106+
Args:
107+
index (int): Index
108+
109+
Returns:
110+
tuple: (image, target) where target is index of the target class.
111+
"""
89112
if self.train:
90113
img, target = self.train_data[index], self.train_labels[index]
91114
else:

torchvision/datasets/coco.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,43 @@
55

66

77
class CocoCaptions(data.Dataset):
8+
"""`MS Coco Captions <http://mscoco.org/dataset/#captions-challenge2015>`_ Dataset.
89
10+
Args:
11+
root (string): Root directory where images are downloaded to.
12+
annFile (string): Path to json annotation file.
13+
transform (callable, optional): A function/transform that takes in an PIL image
14+
and returns a transformed version. E.g, ``transforms.ToTensor``
15+
target_transform (callable, optional): A function/transform that takes in the
16+
target and transforms it.
17+
18+
Example:
19+
20+
.. code:: python
21+
22+
import torchvision.datasets as dset
23+
import torchvision.transforms as transforms
24+
cap = dset.CocoCaptions(root = 'dir where images are',
25+
annFile = 'json annotation file',
26+
transform=transforms.ToTensor())
27+
28+
print('Number of samples: ', len(cap))
29+
img, target = cap[3] # load 4th sample
30+
31+
print("Image Size: ", img.size())
32+
print(target)
33+
34+
Output: ::
35+
36+
Number of samples: 82783
37+
Image Size: (3L, 427L, 640L)
38+
[u'A plane emitting smoke stream flying over a mountain.',
39+
u'A plane darts across a bright blue sky behind a mountain covered in snow',
40+
u'A plane leaves a contrail above the snowy mountain top.',
41+
u'A mountain that has a plane flying overheard in the distance.',
42+
u'A mountain view with a plume of smoke in the background']
43+
44+
"""
945
def __init__(self, root, annFile, transform=None, target_transform=None):
1046
from pycocotools.coco import COCO
1147
self.root = root
@@ -15,6 +51,13 @@ def __init__(self, root, annFile, transform=None, target_transform=None):
1551
self.target_transform = target_transform
1652

1753
def __getitem__(self, index):
54+
"""
55+
Args:
56+
index (int): Index
57+
58+
Returns:
59+
tuple: Tuple (image, target). target is a list of captions for the image.
60+
"""
1861
coco = self.coco
1962
img_id = self.ids[index]
2063
ann_ids = coco.getAnnIds(imgIds=img_id)
@@ -37,6 +80,16 @@ def __len__(self):
3780

3881

3982
class CocoDetection(data.Dataset):
83+
"""`MS Coco Captions <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
84+
85+
Args:
86+
root (string): Root directory where images are downloaded to.
87+
annFile (string): Path to json annotation file.
88+
transform (callable, optional): A function/transform that takes in an PIL image
89+
and returns a transformed version. E.g, ``transforms.ToTensor``
90+
target_transform (callable, optional): A function/transform that takes in the
91+
target and transforms it.
92+
"""
4093

4194
def __init__(self, root, annFile, transform=None, target_transform=None):
4295
from pycocotools.coco import COCO
@@ -47,6 +100,13 @@ def __init__(self, root, annFile, transform=None, target_transform=None):
47100
self.target_transform = target_transform
48101

49102
def __getitem__(self, index):
103+
"""
104+
Args:
105+
index (int): Index
106+
107+
Returns:
108+
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
109+
"""
50110
coco = self.coco
51111
img_id = self.ids[index]
52112
ann_ids = coco.getAnnIds(imgIds=img_id)

torchvision/datasets/folder.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,29 @@ def default_loader(path):
6363

6464

6565
class ImageFolder(data.Dataset):
66+
"""A generic data loader where the images are arranged in this way: ::
67+
68+
root/dog/xxx.png
69+
root/dog/xxy.png
70+
root/dog/xxz.png
71+
72+
root/cat/123.png
73+
root/cat/nsdf3.png
74+
root/cat/asd932_.png
75+
76+
Args:
77+
root (string): Root directory path.
78+
transform (callable, optional): A function/transform that takes in an PIL image
79+
and returns a transformed version. E.g, ``transforms.RandomCrop``
80+
target_transform (callable, optional): A function/transform that takes in the
81+
target and transforms it.
82+
loader (callable, optional): A function to load an image given its path.
83+
84+
Attributes:
85+
classes (list): List of the class names.
86+
class_to_idx (dict): Dict with items (class_name, class_index).
87+
imgs (list): List of (image path, class_index) tuples
88+
"""
6689

6790
def __init__(self, root, transform=None, target_transform=None,
6891
loader=default_loader):
@@ -81,6 +104,13 @@ def __init__(self, root, transform=None, target_transform=None,
81104
self.loader = loader
82105

83106
def __getitem__(self, index):
107+
"""
108+
Args:
109+
index (int): Index
110+
111+
Returns:
112+
tuple: (image, target) where target is class_index of the target class.
113+
"""
84114
path, target = self.imgs[index]
85115
img = self.loader(path)
86116
if self.transform is not None:

torchvision/datasets/lsun.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313

1414
class LSUNClass(data.Dataset):
15-
1615
def __init__(self, db_path, transform=None, target_transform=None):
1716
import lmdb
1817
self.db_path = db_path
@@ -58,8 +57,16 @@ def __repr__(self):
5857

5958
class LSUN(data.Dataset):
6059
"""
61-
db_path = root directory for the database files
62-
classes = 'train' | 'val' | 'test' | ['bedroom_train', 'church_train', ...]
60+
`LSUN <http://lsun.cs.princeton.edu>`_ dataset.
61+
62+
Args:
63+
db_path (string): Root directory for the database files.
64+
classes (string or list): One of {'train', 'val', 'test'} or a list of
65+
categories to load. e,g. ['bedroom_train', 'church_train'].
66+
transform (callable, optional): A function/transform that takes in an PIL image
67+
and returns a transformed version. E.g, ``transforms.RandomCrop``
68+
target_transform (callable, optional): A function/transform that takes in the
69+
target and transforms it.
6370
"""
6471

6572
def __init__(self, db_path, classes='train',
@@ -108,6 +115,13 @@ def __init__(self, db_path, classes='train',
108115
self.target_transform = target_transform
109116

110117
def __getitem__(self, index):
118+
"""
119+
Args:
120+
index (int): Index
121+
122+
Returns:
123+
tuple: Tuple (image, target) where target is the index of the target category.
124+
"""
111125
target = 0
112126
sub = 0
113127
for ind in self.indices:

torchvision/datasets/mnist.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,25 @@
55
import os.path
66
import errno
77
import torch
8-
import json
98
import codecs
10-
import numpy as np
119

1210

1311
class MNIST(data.Dataset):
12+
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
13+
14+
Args:
15+
root (string): Root directory of dataset where ``processed/training.pt``
16+
and ``processed/test.pt`` exist.
17+
train (bool, optional): If True, creates dataset from ``training.pt``,
18+
otherwise from ``test.pt``.
19+
download (bool, optional): If true, downloads the dataset from the internet and
20+
puts it in root directory. If dataset is already downloaded, it is not
21+
downloaded again.
22+
transform (callable, optional): A function/transform that takes in an PIL image
23+
and returns a transformed version. E.g, ``transforms.RandomCrop``
24+
target_transform (callable, optional): A function/transform that takes in the
25+
target and transforms it.
26+
"""
1427
urls = [
1528
'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
1629
'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
@@ -42,6 +55,13 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down
4255
self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))
4356

4457
def __getitem__(self, index):
58+
"""
59+
Args:
60+
index (int): Index
61+
62+
Returns:
63+
tuple: (image, target) where target is index of the target class.
64+
"""
4565
if self.train:
4666
img, target = self.train_data[index], self.train_labels[index]
4767
else:
@@ -70,6 +90,7 @@ def _check_exists(self):
7090
os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
7191

7292
def download(self):
93+
"""Download the MNIST data if it doesn't exist in processed_folder already."""
7394
from six.moves import urllib
7495
import gzip
7596

torchvision/datasets/phototour.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,19 @@
1010

1111

1212
class PhotoTour(data.Dataset):
13+
"""`Learning Local Image Descriptors Data <http://phototour.cs.washington.edu/patches/default.htm>`_ Dataset.
14+
15+
16+
Args:
17+
root (string): Root directory where images are.
18+
name (string): Name of the dataset to load.
19+
transform (callable, optional): A function/transform that takes in an PIL image
20+
and returns a transformed version.
21+
download (bool, optional): If true, downloads the dataset from the internet and
22+
puts it in root directory. If dataset is already downloaded, it is not
23+
downloaded again.
24+
25+
"""
1326
urls = {
1427
'notredame': [
1528
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/notredame.zip',
@@ -59,6 +72,13 @@ def __init__(self, root, name, train=True, transform=None, download=False):
5972
self.data, self.labels, self.matches = torch.load(self.data_file)
6073

6174
def __getitem__(self, index):
75+
"""
76+
Args:
77+
index (int): Index
78+
79+
Returns:
80+
tuple: (data1, data2, matches)
81+
"""
6282
if self.train:
6383
data = self.data[index]
6484
if self.transform is not None:

torchvision/datasets/stl10.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,22 @@
1010

1111

1212
class STL10(CIFAR10):
13+
"""`STL10 <https://cs.stanford.edu/~acoates/stl10/>`_ Dataset.
14+
15+
Args:
16+
root (string): Root directory of dataset where directory
17+
``stl10_binary`` exists.
18+
split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
19+
Accordingly dataset is selected.
20+
transform (callable, optional): A function/transform that takes in an PIL image
21+
and returns a transformed version. E.g, ``transforms.RandomCrop``
22+
target_transform (callable, optional): A function/transform that takes in the
23+
target and transforms it.
24+
download (bool, optional): If true, downloads the dataset from the internet and
25+
puts it in root directory. If dataset is already downloaded, it is not
26+
downloaded again.
27+
28+
"""
1329
base_folder = 'stl10_binary'
1430
url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
1531
filename = "stl10_binary.tar.gz"
@@ -67,6 +83,13 @@ def __init__(self, root, split='train',
6783
self.classes = f.read().splitlines()
6884

6985
def __getitem__(self, index):
86+
"""
87+
Args:
88+
index (int): Index
89+
90+
Returns:
91+
tuple: (image, target) where target is index of the target class.
92+
"""
7093
if self.labels is not None:
7194
img, target = self.data[index], int(self.labels[index])
7295
else:

torchvision/datasets/svhn.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,27 @@
33
from PIL import Image
44
import os
55
import os.path
6-
import errno
76
import numpy as np
8-
import sys
97
from .utils import download_url, check_integrity
108

119

1210
class SVHN(data.Dataset):
11+
"""`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.
12+
13+
Args:
14+
root (string): Root directory of dataset where directory
15+
``SVHN`` exists.
16+
split (string): One of {'train', 'test', 'extra'}.
17+
Accordingly dataset is selected. 'extra' is Extra training set.
18+
transform (callable, optional): A function/transform that takes in an PIL image
19+
and returns a transformed version. E.g, ``transforms.RandomCrop``
20+
target_transform (callable, optional): A function/transform that takes in the
21+
target and transforms it.
22+
download (bool, optional): If true, downloads the dataset from the internet and
23+
puts it in root directory. If dataset is already downloaded, it is not
24+
downloaded again.
25+
26+
"""
1327
url = ""
1428
filename = ""
1529
file_md5 = ""
@@ -56,6 +70,13 @@ def __init__(self, root, split='train',
5670
self.data = np.transpose(self.data, (3, 2, 0, 1))
5771

5872
def __getitem__(self, index):
73+
"""
74+
Args:
75+
index (int): Index
76+
77+
Returns:
78+
tuple: (image, target) where target is index of the target class.
79+
"""
5980
img, target = self.data[index], self.labels[index]
6081

6182
# doing this so that it is consistent with all other datasets

0 commit comments

Comments
 (0)