
Commit 06cbdb5

ppwwyyxx authored and fmassa committed
Speed up nms_cuda (#1704)
1. Let the IoU function compare against the threshold directly. This avoids a division (sketched below); a similar strategy is used in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/non_max_suppression_op.cu.cc
2. Only compute the upper triangle of the mask. This speeds up the kernel by about 20% (tested on a GTX 1080 Ti, with 20 input cases dumped from a Mask R-CNN inference job).
1 parent 2a17422 commit 06cbdb5
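The first change works because the union area Sa + Sb - interS is never negative for [x1, y1, x2, y2] boxes, so comparing interS / (Sa + Sb - interS) against the threshold selects the same pairs as comparing interS against threshold * (Sa + Sb - interS). A minimal host-side sketch of the rearranged test, mirroring the patched devIoU (the box layout follows the kernel; the function name and everything else here is illustrative, not from the torchvision sources):

#include <algorithm>

// Division-free "IoU > threshold" test, same algebra as the patched devIoU.
// Boxes are [x1, y1, x2, y2]; the rearrangement is valid because the union
// area (area_a + area_b - inter) is never negative for such boxes.
inline bool iou_above_threshold(const float a[4], const float b[4],
                                float threshold) {
  float left = std::max(a[0], b[0]), right = std::min(a[2], b[2]);
  float top = std::max(a[1], b[1]), bottom = std::min(a[3], b[3]);
  float width = std::max(right - left, 0.f);
  float height = std::max(bottom - top, 0.f);
  float inter = width * height;                       // intersection area
  float area_a = (a[2] - a[0]) * (a[3] - a[1]);
  float area_b = (b[2] - b[0]) * (b[3] - b[1]);
  // inter / (area_a + area_b - inter) > threshold, rearranged to
  // avoid the division:
  return inter > threshold * (area_a + area_b - inter);
}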

File tree

1 file changed: +4 −4 lines


torchvision/csrc/cuda/nms_cuda.cu

Lines changed: 4 additions & 4 deletions
@@ -11,14 +11,14 @@
 int const threadsPerBlock = sizeof(unsigned long long) * 8;
 
 template <typename T>
-__device__ inline float devIoU(T const* const a, T const* const b) {
+__device__ inline bool devIoU(T const* const a, T const* const b, const float threshold) {
   T left = max(a[0], b[0]), right = min(a[2], b[2]);
   T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
   T width = max(right - left, (T)0), height = max(bottom - top, (T)0);
   T interS = width * height;
   T Sa = (a[2] - a[0]) * (a[3] - a[1]);
   T Sb = (b[2] - b[0]) * (b[3] - b[1]);
-  return interS / (Sa + Sb - interS);
+  return interS > threshold * (Sa + Sb - interS);
 }
 
 template <typename T>
@@ -30,7 +30,7 @@ __global__ void nms_kernel(
   const int row_start = blockIdx.y;
   const int col_start = blockIdx.x;
 
-  // if (row_start > col_start) return;
+  if (row_start > col_start) return;
 
   const int row_size =
       min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
@@ -60,7 +60,7 @@ __global__ void nms_kernel(
       start = threadIdx.x + 1;
     }
     for (i = start; i < col_size; i++) {
-      if (devIoU<T>(cur_box, block_boxes + i * 4) > iou_threshold) {
+      if (devIoU<T>(cur_box, block_boxes + i * 4, iou_threshold)) {
        t |= 1ULL << i;
      }
    }
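The second change is safe because the pairwise mask is only ever consumed in "upper-triangle" order: boxes enter NMS sorted by descending score, and the CPU-side reduction that follows the kernel reads, for each kept box i, only the column blocks at or after box i's own block. Blocks with row_start > col_start are therefore never inspected and can be skipped. Below is a simplified sketch of such a reduction, assuming a mask layout of col_blocks 64-bit words per box as in nms_cuda.cu; the function and variable names are illustrative, not the exact torchvision host code.

#include <cstdint>
#include <vector>

// Simplified CPU-side reduction over the pairwise suppression mask.
// mask holds n_boxes rows of col_blocks 64-bit words; bit j of row i means
// IoU(box_i, box_j) exceeded the threshold. Boxes are assumed sorted by
// descending score, so row i is only consulted when box i is kept, and only
// its column blocks at or after i's own block are ORed in -- the strictly
// lower-triangle blocks are never read, which is what allows the kernel to
// return early when row_start > col_start.
std::vector<int64_t> gather_keep(const std::vector<uint64_t>& mask,
                                 int n_boxes, int col_blocks) {
  std::vector<uint64_t> remv(col_blocks, 0);
  std::vector<int64_t> keep;
  for (int i = 0; i < n_boxes; ++i) {
    int nblock = i / 64, inblock = i % 64;
    if (!(remv[nblock] & (1ULL << inblock))) {        // box i survives
      keep.push_back(i);
      const uint64_t* row = mask.data() + static_cast<size_t>(i) * col_blocks;
      for (int j = nblock; j < col_blocks; ++j)       // upper triangle only
        remv[j] |= row[j];
    }
  }
  return keep;
}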
