@@ -69,10 +69,11 @@ def batched_nms(
6969 _log_api_usage_once (batched_nms )
7070 # Benchmarks that drove the following thresholds are at
7171 # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
72- if boxes .numel () > (4000 if boxes .device .type == "cpu" else 20000 ) and not torchvision ._is_tracing ():
73- return _batched_nms_vanilla (boxes , scores , idxs , iou_threshold )
74- else :
75- return _batched_nms_coordinate_trick (boxes , scores , idxs , iou_threshold )
72+ return _batched_nms_vanilla (boxes , scores , idxs , iou_threshold )
73+ #if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
74+ # return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
75+ #else:
76+ # return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
7677
7778
7879@torch .jit ._script_if_tracing
@@ -104,7 +105,8 @@ def _batched_nms_vanilla(
104105) -> Tensor :
105106 # Based on Detectron2 implementation, just manually call nms() on each class independently
106107 keep_mask = torch .zeros_like (scores , dtype = torch .bool )
107- for class_id in torch .unique (idxs ):
108+ #for class_id in torch.unique(idxs):
109+ for class_id in idxs :
108110 curr_indices = torch .where (idxs == class_id )[0 ]
109111 curr_keep_indices = nms (boxes [curr_indices ], scores [curr_indices ], iou_threshold )
110112 keep_mask [curr_indices [curr_keep_indices ]] = True
0 commit comments