2
2
from PIL import Image
3
3
import os
4
4
import os .path
5
- from typing import Any , Callable , Optional , Tuple
5
+ from typing import Any , Callable , Optional , Tuple , List
6
6
7
7
8
- class CocoCaptions (VisionDataset ):
9
- """`MS Coco Captions <https://cocodataset.org/#captions-2015 >`_ Dataset.
8
+ class CocoDetection (VisionDataset ):
9
+ """`MS Coco Detection <https://cocodataset.org/#detection-2016 >`_ Dataset.
10
10
11
11
Args:
12
12
root (string): Root directory where images are downloaded to.
@@ -17,77 +17,45 @@ class CocoCaptions(VisionDataset):
17
17
target and transforms it.
18
18
transforms (callable, optional): A function/transform that takes input sample and its target as entry
19
19
and returns a transformed version.
20
-
21
- Example:
22
-
23
- .. code:: python
24
-
25
- import torchvision.datasets as dset
26
- import torchvision.transforms as transforms
27
- cap = dset.CocoCaptions(root = 'dir where images are',
28
- annFile = 'json annotation file',
29
- transform=transforms.ToTensor())
30
-
31
- print('Number of samples: ', len(cap))
32
- img, target = cap[3] # load 4th sample
33
-
34
- print("Image Size: ", img.size())
35
- print(target)
36
-
37
- Output: ::
38
-
39
- Number of samples: 82783
40
- Image Size: (3L, 427L, 640L)
41
- [u'A plane emitting smoke stream flying over a mountain.',
42
- u'A plane darts across a bright blue sky behind a mountain covered in snow',
43
- u'A plane leaves a contrail above the snowy mountain top.',
44
- u'A mountain that has a plane flying overheard in the distance.',
45
- u'A mountain view with a plume of smoke in the background']
46
-
47
20
"""
48
21
49
22
def __init__ (
50
- self ,
51
- root : str ,
52
- annFile : str ,
53
- transform : Optional [Callable ] = None ,
54
- target_transform : Optional [Callable ] = None ,
55
- transforms : Optional [Callable ] = None ,
56
- ) -> None :
57
- super (CocoCaptions , self ).__init__ (root , transforms , transform , target_transform )
23
+ self ,
24
+ root : str ,
25
+ annFile : str ,
26
+ transform : Optional [Callable ] = None ,
27
+ target_transform : Optional [Callable ] = None ,
28
+ transforms : Optional [Callable ] = None ,
29
+ ):
30
+ super ().__init__ (root , transforms , transform , target_transform )
58
31
from pycocotools .coco import COCO
32
+
59
33
self .coco = COCO (annFile )
60
34
self .ids = list (sorted (self .coco .imgs .keys ()))
61
35
62
- def __getitem__ (self , index : int ) -> Tuple [Any , Any ]:
63
- """
64
- Args:
65
- index (int): Index
66
-
67
- Returns:
68
- tuple: Tuple (image, target). target is a list of captions for the image.
69
- """
70
- coco = self .coco
71
- img_id = self .ids [index ]
72
- ann_ids = coco .getAnnIds (imgIds = img_id )
73
- anns = coco .loadAnns (ann_ids )
74
- target = [ann ['caption' ] for ann in anns ]
36
+ def _load_image (self , id : int ) -> Image .Image :
37
+ path = self .coco .loadImgs (id )[0 ]["file_name" ]
38
+ return Image .open (os .path .join (self .root , path )).convert ("RGB" )
75
39
76
- path = coco .loadImgs (img_id )[0 ]['file_name' ]
40
+ def _load_target (self , id ) -> List [Any ]:
41
+ return self .coco .loadAnns (self .coco .getAnnIds (id ))
77
42
78
- img = Image .open (os .path .join (self .root , path )).convert ('RGB' )
43
+ def __getitem__ (self , index : int ) -> Tuple [Any , Any ]:
44
+ id = self .ids [index ]
45
+ image = self ._load_image (id )
46
+ target = self ._load_target (id )
79
47
80
48
if self .transforms is not None :
81
- img , target = self .transforms (img , target )
49
+ image , target = self .transforms (image , target )
82
50
83
- return img , target
51
+ return image , target
84
52
85
53
def __len__ (self ) -> int :
86
54
return len (self .ids )
87
55
88
56
89
- class CocoDetection ( VisionDataset ):
90
- """`MS Coco Detection <https://cocodataset.org/#detection-2016 >`_ Dataset.
57
+ class CocoCaptions ( CocoDetection ):
58
+ """`MS Coco Captions <https://cocodataset.org/#captions-2015 >`_ Dataset.
91
59
92
60
Args:
93
61
root (string): Root directory where images are downloaded to.
@@ -98,41 +66,34 @@ class CocoDetection(VisionDataset):
98
66
target and transforms it.
99
67
transforms (callable, optional): A function/transform that takes input sample and its target as entry
100
68
and returns a transformed version.
101
- """
102
69
103
- def __init__ (
104
- self ,
105
- root : str ,
106
- annFile : str ,
107
- transform : Optional [Callable ] = None ,
108
- target_transform : Optional [Callable ] = None ,
109
- transforms : Optional [Callable ] = None ,
110
- ) -> None :
111
- super (CocoDetection , self ).__init__ (root , transforms , transform , target_transform )
112
- from pycocotools .coco import COCO
113
- self .coco = COCO (annFile )
114
- self .ids = list (sorted (self .coco .imgs .keys ()))
70
+ Example:
115
71
116
- def __getitem__ (self , index : int ) -> Tuple [Any , Any ]:
117
- """
118
- Args:
119
- index (int): Index
72
+ .. code:: python
73
+
74
+ import torchvision.datasets as dset
75
+ import torchvision.transforms as transforms
76
+ cap = dset.CocoCaptions(root = 'dir where images are',
77
+ annFile = 'json annotation file',
78
+ transform=transforms.ToTensor())
120
79
121
- Returns:
122
- tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
123
- """
124
- coco = self .coco
125
- img_id = self .ids [index ]
126
- ann_ids = coco .getAnnIds (imgIds = img_id )
127
- target = coco .loadAnns (ann_ids )
80
+ print('Number of samples: ', len(cap))
81
+ img, target = cap[3] # load 4th sample
128
82
129
- path = coco .loadImgs (img_id )[0 ]['file_name' ]
83
+ print("Image Size: ", img.size())
84
+ print(target)
130
85
131
- img = Image .open (os .path .join (self .root , path )).convert ('RGB' )
132
- if self .transforms is not None :
133
- img , target = self .transforms (img , target )
86
+ Output: ::
134
87
135
- return img , target
88
+ Number of samples: 82783
89
+ Image Size: (3L, 427L, 640L)
90
+ [u'A plane emitting smoke stream flying over a mountain.',
91
+ u'A plane darts across a bright blue sky behind a mountain covered in snow',
92
+ u'A plane leaves a contrail above the snowy mountain top.',
93
+ u'A mountain that has a plane flying overheard in the distance.',
94
+ u'A mountain view with a plume of smoke in the background']
136
95
137
- def __len__ (self ) -> int :
138
- return len (self .ids )
96
+ """
97
+
98
+ def _load_target (self , id ) -> List [str ]:
99
+ return [ann ["caption" ] for ann in super ()._load_target (id )]
0 commit comments