5
5
6
6
7
7
class CocoCaptions (data .Dataset ):
8
+ """`MS Coco Captions <http://mscoco.org/dataset/#captions-challenge2015>`_ Dataset.
8
9
10
+ Args:
11
+ root (string): Root directory where images are downloaded to.
12
+ annFile (string): Path to json annotation file.
13
+ transform (callable, optional): A function/transform that takes in a PIL image
14
+ and returns a transformed version. E.g., ``transforms.ToTensor``
15
+ target_transform (callable, optional): A function/transform that takes in the
16
+ target and transforms it.
17
+
18
+ Example:
19
+
20
+ .. code:: python
21
+
22
+ import torchvision.datasets as dset
23
+ import torchvision.transforms as transforms
24
+ cap = dset.CocoCaptions(root = 'dir where images are',
25
+ annFile = 'json annotation file',
26
+ transform=transforms.ToTensor())
27
+
28
+ print('Number of samples: ', len(cap))
29
+ img, target = cap[3] # load 4th sample
30
+
31
+ print("Image Size: ", img.size())
32
+ print(target)
33
+
34
+ Output: ::
35
+
36
+ Number of samples: 82783
37
+ Image Size: (3L, 427L, 640L)
38
+ [u'A plane emitting smoke stream flying over a mountain.',
39
+ u'A plane darts across a bright blue sky behind a mountain covered in snow',
40
+ u'A plane leaves a contrail above the snowy mountain top.',
41
+ u'A mountain that has a plane flying overheard in the distance.',
42
+ u'A mountain view with a plume of smoke in the background']
43
+
44
+ """
9
45
def __init__ (self , root , annFile , transform = None , target_transform = None ):
10
46
from pycocotools .coco import COCO
11
47
self .root = root
@@ -15,6 +51,13 @@ def __init__(self, root, annFile, transform=None, target_transform=None):
15
51
self .target_transform = target_transform
16
52
17
53
def __getitem__ (self , index ):
54
+ """
55
+ Args:
56
+ index (int): Index
57
+
58
+ Returns:
59
+ tuple: Tuple (image, target). target is a list of captions for the image.
60
+ """
18
61
coco = self .coco
19
62
img_id = self .ids [index ]
20
63
ann_ids = coco .getAnnIds (imgIds = img_id )
@@ -37,6 +80,16 @@ def __len__(self):
37
80
38
81
39
82
class CocoDetection (data .Dataset ):
83
+ """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
84
+
85
+ Args:
86
+ root (string): Root directory where images are downloaded to.
87
+ annFile (string): Path to json annotation file.
88
+ transform (callable, optional): A function/transform that takes in a PIL image
89
+ and returns a transformed version. E.g., ``transforms.ToTensor``
90
+ target_transform (callable, optional): A function/transform that takes in the
91
+ target and transforms it.
92
+ """
40
93
41
94
def __init__ (self , root , annFile , transform = None , target_transform = None ):
42
95
from pycocotools .coco import COCO
@@ -47,6 +100,13 @@ def __init__(self, root, annFile, transform=None, target_transform=None):
47
100
self .target_transform = target_transform
48
101
49
102
def __getitem__ (self , index ):
103
+ """
104
+ Args:
105
+ index (int): Index
106
+
107
+ Returns:
108
+ tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
109
+ """
50
110
coco = self .coco
51
111
img_id = self .ids [index ]
52
112
ann_ids = coco .getAnnIds (imgIds = img_id )
0 commit comments