11import torch
2- from typing import List
2+ from typing import List , Optional
33import numpy as np
44import numba
55
88
99
1010@numba .jit (nopython = True , parallel = True )
11- def _instance_iou_cpu (instance_idx , instance_offsets , gt_instances , gt_instance_sizes , num_gt_instances ):
11+ def _instance_iou_cpu (
12+ instance_idx , instance_offsets , gt_instances , gt_instance_sizes , num_gt_instances : np .array , batch : np .array ,
13+ ):
1214 num_proposed_instances = len (instance_offsets ) - 1
13- iou = np .zeros ((num_proposed_instances , num_gt_instances ))
15+ iou = np .zeros ((num_proposed_instances , num_gt_instances .sum ()))
16+ offset_num_gt_instances = np .concatenate ((np .array ([0 ]), num_gt_instances .cumsum ()))
1417 for proposed_instance in range (num_proposed_instances ):
1518 instance = instance_idx [instance_offsets [proposed_instance ] : instance_offsets [proposed_instance + 1 ]]
16- for instance_id in numba .prange (1 , num_gt_instances + 1 ):
19+ sample_idx = batch [instance [0 ]]
20+ gt_count_offset = offset_num_gt_instances [sample_idx ]
21+ sample_instance_count = num_gt_instances [sample_idx ]
22+ for instance_id in numba .prange (1 , sample_instance_count + 1 ):
1723 intersection = 0
1824 for idx in instance :
1925 if gt_instances [idx ] == instance_id :
2026 intersection += 1
21- iou [proposed_instance , instance_id - 1 ] = intersection / float (
22- len (instance ) + gt_instance_sizes [instance_id - 1 ] - intersection
27+ iou [proposed_instance , gt_count_offset + instance_id - 1 ] = intersection / float (
28+ len (instance ) + gt_instance_sizes [gt_count_offset + instance_id - 1 ] - intersection
2329 )
2430 return iou
2531
2632
27- def instance_iou (instance_idx : List [torch .Tensor ], gt_instances : torch .Tensor ):
33+ def instance_iou (
34+ instance_idx : List [torch .Tensor ], gt_instances : torch .Tensor , batch : Optional [torch .Tensor ] = None ,
35+ ):
2836 """ Computes the IoU between each proposed instance in instance_idx and ground truth instances. Returns a
2937 tensor of shape [instance_idx.shape[0], num_instances] that contains the iou between the proposed instances and all gt instances
3038 Instance label 0 is reserved for non instance points
@@ -41,29 +49,48 @@ def instance_iou(instance_idx: List[torch.Tensor], gt_instances: torch.Tensor):
4149 -------
4250 ious: torch.Tensor[nb_proposals, nb_groundtruth]
4351 """
52+ if batch is None :
53+ batch = torch .zeros_like (gt_instances )
54+
55+ # Gather number of gt instances per batch and size of those instances
4456 gt_instance_sizes = []
45- num_gt_instances = torch .max (gt_instances ).item ()
46- for instance_id in range (1 , num_gt_instances + 1 ):
47- gt_instance_sizes .append (torch .sum (gt_instances == instance_id ))
57+ num_gt_instances = []
58+ batch_size = batch [- 1 ] + 1
59+ for s in range (batch_size ):
60+ batch_mask = batch == s
61+ sample_gt_instances = gt_instances [batch_mask ]
62+ sample_num_gt_instances = torch .max (sample_gt_instances ).item ()
63+ num_gt_instances .append (sample_num_gt_instances )
64+ for instance_id in range (1 , sample_num_gt_instances + 1 ):
65+ gt_instance_sizes .append (torch .sum (sample_gt_instances == instance_id ))
4866 gt_instance_sizes = torch .stack (gt_instance_sizes )
67+ num_gt_instances = torch .tensor (num_gt_instances )
4968
69+ # Instance offset when flatten
5070 instance_offsets = [0 ]
5171 cum_offset = 0
5272 for instance in instance_idx :
5373 cum_offset += instance .shape [0 ]
5474 instance_offsets .append (cum_offset )
5575
76+ # Compute ious
5677 instance_idx = torch .cat (instance_idx )
5778 if gt_instances .is_cuda :
5879 return tpcuda .instance_iou_cuda (
59- instance_idx , torch .tensor (instance_offsets ).cuda (), gt_instances , gt_instance_sizes , num_gt_instances ,
80+ instance_idx .cuda (),
81+ torch .tensor (instance_offsets ).cuda (),
82+ gt_instances .cuda (),
83+ gt_instance_sizes .cuda (),
84+ num_gt_instances .cuda (),
85+ batch .cuda (),
6086 )
6187 else :
6288 res = _instance_iou_cpu (
6389 instance_idx .numpy (),
6490 np .asarray (instance_offsets ),
6591 gt_instances .numpy (),
6692 gt_instance_sizes .numpy (),
67- num_gt_instances ,
93+ num_gt_instances .numpy (),
94+ batch .numpy (),
6895 )
6996 return torch .tensor (res ).float ()
0 commit comments