diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 9674d5bfa1d..48df4d85cc7 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -78,7 +78,8 @@ def batched_nms( _log_api_usage_once(batched_nms) # Benchmarks that drove the following thresholds are at # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339 - if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing(): + # and https://github.com/pytorch/vision/pull/8925 + if boxes.numel() > (4000 if boxes.device.type == "cpu" else 100_000) and not torchvision._is_tracing(): return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold) else: return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)