Skip to content

Commit 62e185c

Browse files
pmeierfmassa
andauthored
improve Coco implementation (#3417)
Co-authored-by: Francisco Massa <[email protected]>
1 parent a6f3f95 commit 62e185c

File tree

1 file changed

+49
-88
lines changed

1 file changed

+49
-88
lines changed

torchvision/datasets/coco.py

Lines changed: 49 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
from PIL import Image
33
import os
44
import os.path
5-
from typing import Any, Callable, Optional, Tuple
5+
from typing import Any, Callable, Optional, Tuple, List
66

77

8-
class CocoCaptions(VisionDataset):
9-
"""`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.
8+
class CocoDetection(VisionDataset):
9+
"""`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
1010
1111
Args:
1212
root (string): Root directory where images are downloaded to.
@@ -17,77 +17,45 @@ class CocoCaptions(VisionDataset):
1717
target and transforms it.
1818
transforms (callable, optional): A function/transform that takes input sample and its target as entry
1919
and returns a transformed version.
20-
21-
Example:
22-
23-
.. code:: python
24-
25-
import torchvision.datasets as dset
26-
import torchvision.transforms as transforms
27-
cap = dset.CocoCaptions(root = 'dir where images are',
28-
annFile = 'json annotation file',
29-
transform=transforms.ToTensor())
30-
31-
print('Number of samples: ', len(cap))
32-
img, target = cap[3] # load 4th sample
33-
34-
print("Image Size: ", img.size())
35-
print(target)
36-
37-
Output: ::
38-
39-
Number of samples: 82783
40-
Image Size: (3L, 427L, 640L)
41-
[u'A plane emitting smoke stream flying over a mountain.',
42-
u'A plane darts across a bright blue sky behind a mountain covered in snow',
43-
u'A plane leaves a contrail above the snowy mountain top.',
44-
u'A mountain that has a plane flying overheard in the distance.',
45-
u'A mountain view with a plume of smoke in the background']
46-
4720
"""
4821

4922
def __init__(
50-
self,
51-
root: str,
52-
annFile: str,
53-
transform: Optional[Callable] = None,
54-
target_transform: Optional[Callable] = None,
55-
transforms: Optional[Callable] = None,
56-
) -> None:
57-
super(CocoCaptions, self).__init__(root, transforms, transform, target_transform)
23+
self,
24+
root: str,
25+
annFile: str,
26+
transform: Optional[Callable] = None,
27+
target_transform: Optional[Callable] = None,
28+
transforms: Optional[Callable] = None,
29+
):
30+
super().__init__(root, transforms, transform, target_transform)
5831
from pycocotools.coco import COCO
32+
5933
self.coco = COCO(annFile)
6034
self.ids = list(sorted(self.coco.imgs.keys()))
6135

62-
def __getitem__(self, index: int) -> Tuple[Any, Any]:
63-
"""
64-
Args:
65-
index (int): Index
66-
67-
Returns:
68-
tuple: Tuple (image, target). target is a list of captions for the image.
69-
"""
70-
coco = self.coco
71-
img_id = self.ids[index]
72-
ann_ids = coco.getAnnIds(imgIds=img_id)
73-
anns = coco.loadAnns(ann_ids)
74-
target = [ann['caption'] for ann in anns]
36+
def _load_image(self, id: int) -> Image.Image:
37+
path = self.coco.loadImgs(id)[0]["file_name"]
38+
return Image.open(os.path.join(self.root, path)).convert("RGB")
7539

76-
path = coco.loadImgs(img_id)[0]['file_name']
40+
def _load_target(self, id) -> List[Any]:
41+
return self.coco.loadAnns(self.coco.getAnnIds(id))
7742

78-
img = Image.open(os.path.join(self.root, path)).convert('RGB')
43+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
44+
id = self.ids[index]
45+
image = self._load_image(id)
46+
target = self._load_target(id)
7947

8048
if self.transforms is not None:
81-
img, target = self.transforms(img, target)
49+
image, target = self.transforms(image, target)
8250

83-
return img, target
51+
return image, target
8452

8553
def __len__(self) -> int:
8654
return len(self.ids)
8755

8856

89-
class CocoDetection(VisionDataset):
90-
"""`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
57+
class CocoCaptions(CocoDetection):
58+
"""`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.
9159
9260
Args:
9361
root (string): Root directory where images are downloaded to.
@@ -98,41 +66,34 @@ class CocoDetection(VisionDataset):
9866
target and transforms it.
9967
transforms (callable, optional): A function/transform that takes input sample and its target as entry
10068
and returns a transformed version.
101-
"""
10269
103-
def __init__(
104-
self,
105-
root: str,
106-
annFile: str,
107-
transform: Optional[Callable] = None,
108-
target_transform: Optional[Callable] = None,
109-
transforms: Optional[Callable] = None,
110-
) -> None:
111-
super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
112-
from pycocotools.coco import COCO
113-
self.coco = COCO(annFile)
114-
self.ids = list(sorted(self.coco.imgs.keys()))
70+
Example:
11571
116-
def __getitem__(self, index: int) -> Tuple[Any, Any]:
117-
"""
118-
Args:
119-
index (int): Index
72+
.. code:: python
73+
74+
import torchvision.datasets as dset
75+
import torchvision.transforms as transforms
76+
cap = dset.CocoCaptions(root = 'dir where images are',
77+
annFile = 'json annotation file',
78+
transform=transforms.ToTensor())
12079
121-
Returns:
122-
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
123-
"""
124-
coco = self.coco
125-
img_id = self.ids[index]
126-
ann_ids = coco.getAnnIds(imgIds=img_id)
127-
target = coco.loadAnns(ann_ids)
80+
print('Number of samples: ', len(cap))
81+
img, target = cap[3] # load 4th sample
12882
129-
path = coco.loadImgs(img_id)[0]['file_name']
83+
print("Image Size: ", img.size())
84+
print(target)
13085
131-
img = Image.open(os.path.join(self.root, path)).convert('RGB')
132-
if self.transforms is not None:
133-
img, target = self.transforms(img, target)
86+
Output: ::
13487
135-
return img, target
88+
Number of samples: 82783
89+
Image Size: (3L, 427L, 640L)
90+
[u'A plane emitting smoke stream flying over a mountain.',
91+
u'A plane darts across a bright blue sky behind a mountain covered in snow',
92+
u'A plane leaves a contrail above the snowy mountain top.',
93+
u'A mountain that has a plane flying overheard in the distance.',
94+
u'A mountain view with a plume of smoke in the background']
13695
137-
def __len__(self) -> int:
138-
return len(self.ids)
96+
"""
97+
98+
def _load_target(self, id) -> List[str]:
99+
return [ann["caption"] for ann in super()._load_target(id)]

0 commit comments

Comments
 (0)