Skip to content

Commit 9ffeeb5

Browse files
authored
Add support for evaluating masks and keypoints for custom dataset (#938)
1 parent f4d43cc commit 9ffeeb5

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

references/detection/coco_utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,18 +151,26 @@ def convert_to_coco_api(ds):
151151
for img_idx in range(len(ds)):
152152
# find better way to get target
153153
# targets = ds.get_annotations(img_idx)
154-
_, targets = ds[img_idx]
154+
img, targets = ds[img_idx]
155155
image_id = targets["image_id"].item()
156156
img_dict = {}
157157
img_dict['id'] = image_id
158+
img_dict['height'] = img.shape[-2]
159+
img_dict['width'] = img.shape[-1]
158160
dataset['images'].append(img_dict)
159161
bboxes = targets["boxes"]
160162
bboxes[:, 2:] -= bboxes[:, :2]
161163
bboxes = bboxes.tolist()
162164
labels = targets['labels'].tolist()
163165
areas = targets['area'].tolist()
164166
iscrowd = targets['iscrowd'].tolist()
165-
# TODO need to add masks as well
167+
if 'masks' in targets:
168+
masks = targets['masks']
169+
# make masks Fortran contiguous for coco_mask
170+
masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
171+
if 'keypoints' in targets:
172+
keypoints = targets['keypoints']
173+
keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
166174
num_objs = len(bboxes)
167175
for i in range(num_objs):
168176
ann = {}
@@ -173,6 +181,11 @@ def convert_to_coco_api(ds):
173181
ann['area'] = areas[i]
174182
ann['iscrowd'] = iscrowd[i]
175183
ann['id'] = ann_id
184+
if 'masks' in targets:
185+
ann["segmentation"] = coco_mask.encode(masks[i].numpy())
186+
if 'keypoints' in targets:
187+
ann['keypoints'] = keypoints[i]
188+
ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3])
176189
dataset['annotations'].append(ann)
177190
ann_id += 1
178191
dataset['categories'] = [{'id': i} for i in sorted(categories)]

0 commit comments

Comments
 (0)