Commit e239710

Ghelfi and NicolasHug authored
Speed-up NMS by keeping index gathering on cuda device (#8766)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent ae9bd7e commit e239710

File tree

1 file changed: +61 -26

torchvision/csrc/ops/cuda/nms_kernel.cu

Lines changed: 61 additions & 26 deletions
@@ -77,6 +77,51 @@ __global__ void nms_kernel_impl(
   }
 }
 
+__global__ static void gather_keep_from_mask(
+    bool* keep,
+    const unsigned long long* dev_mask,
+    const int n_boxes) {
+  // Taken and adapted from mmcv
+  // https://github.com/open-mmlab/mmcv/blob/03ce9208d18c0a63d7ffa087ea1c2f5661f2441a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh#L76
+  const int col_blocks = ceil_div(n_boxes, threadsPerBlock);
+  const int thread_id = threadIdx.x;
+
+  // Mark the bboxes which have been removed.
+  extern __shared__ unsigned long long removed[];
+
+  // Initialize removed.
+  for (int i = thread_id; i < col_blocks; i += blockDim.x) {
+    removed[i] = 0;
+  }
+  __syncthreads();
+
+  for (int nblock = 0; nblock < col_blocks; nblock++) {
+    auto removed_val = removed[nblock];
+    __syncthreads();
+    const int i_offset = nblock * threadsPerBlock;
+#pragma unroll
+    for (int inblock = 0; inblock < threadsPerBlock; inblock++) {
+      const int i = i_offset + inblock;
+      if (i >= n_boxes)
+        break;
+      // Select a candidate, check if it should be kept.
+      if (!(removed_val & (1ULL << inblock))) {
+        if (thread_id == 0) {
+          keep[i] = true;
+        }
+        auto p = dev_mask + i * col_blocks;
+        // Remove all bboxes which overlap the candidate.
+        for (int j = thread_id; j < col_blocks; j += blockDim.x) {
+          if (j >= nblock)
+            removed[j] |= p[j];
+        }
+        __syncthreads();
+        removed_val = removed[nblock];
+      }
+    }
+  }
+}
+
 at::Tensor nms_kernel(
     const at::Tensor& dets,
     const at::Tensor& scores,
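
For orientation, here is a minimal Python sketch (not part of this commit) of the serial suppression logic that gather_keep_from_mask parallelizes. It assumes the layout produced by nms_kernel_impl: boxes are pre-sorted by descending score, and bit k of the 64-bit word mask[i * col_blocks + j] is set when box i suppresses box j * 64 + k.

# Minimal Python sketch (not from this commit) of the serial logic that
# gather_keep_from_mask parallelizes.
def gather_keep(mask, n_boxes, bits_per_word=64):
    col_blocks = (n_boxes + bits_per_word - 1) // bits_per_word
    removed = [0] * col_blocks          # bitset of already-suppressed boxes
    keep = [False] * n_boxes
    for i in range(n_boxes):            # boxes are sorted by descending score
        word, bit = divmod(i, bits_per_word)
        if not (removed[word] >> bit) & 1:
            keep[i] = True              # box i survives NMS
            # Mark every box that box i suppresses.
            for j in range(word, col_blocks):
                removed[j] |= mask[i * col_blocks + j]
    return keep
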
@@ -133,35 +178,25 @@ at::Tensor nms_kernel(
         (unsigned long long*)mask.data_ptr<int64_t>());
       });
 
-  at::Tensor mask_cpu = mask.to(at::kCPU);
-  unsigned long long* mask_host =
-      (unsigned long long*)mask_cpu.data_ptr<int64_t>();
-
-  std::vector<unsigned long long> remv(col_blocks);
-  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
-
   at::Tensor keep =
-      at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
-  int64_t* keep_out = keep.data_ptr<int64_t>();
-
-  int num_to_keep = 0;
-  for (int i = 0; i < dets_num; i++) {
-    int nblock = i / threadsPerBlock;
-    int inblock = i % threadsPerBlock;
-
-    if (!(remv[nblock] & (1ULL << inblock))) {
-      keep_out[num_to_keep++] = i;
-      unsigned long long* p = mask_host + i * col_blocks;
-      for (int j = nblock; j < col_blocks; j++) {
-        remv[j] |= p[j];
-      }
-    }
-  }
+      at::zeros({dets_num}, dets.options().dtype(at::kBool).device(at::kCUDA));
+
+  // Unwrap the mask to fill keep with proper values.
+  // Keeping the unwrap on device instead of applying iterative for loops on
+  // cpu prevents the device -> cpu -> device transfer that could be a
+  // bottleneck for a large number of boxes.
+  // See https://github.com/pytorch/vision/issues/8713 for more details.
+  gather_keep_from_mask<<<
+      1,
+      min(col_blocks, threadsPerBlock),
+      col_blocks * sizeof(unsigned long long),
+      stream>>>(
+      keep.data_ptr<bool>(),
+      (unsigned long long*)mask.data_ptr<int64_t>(),
+      dets_num);
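
As a user-level illustration of that comment (hypothetical sizes, not part of this commit): with this change, torchvision.ops.nms on CUDA inputs runs end to end on the device, with no host round trip for the suppression mask.

import torch
from torchvision.ops import nms

# Hypothetical inputs: random (x1, y1, x2, y2) boxes and scores on the GPU.
boxes = torch.rand(100_000, 4, device="cuda") * 512
boxes[:, 2:] += boxes[:, :2]   # ensure x2 > x1 and y2 > y1
scores = torch.rand(100_000, device="cuda")

# Indices of kept boxes, sorted by descending score, still on the GPU.
keep = nms(boxes, scores, iou_threshold=0.5)
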
 
   AT_CUDA_CHECK(cudaGetLastError());
-  return order_t.index(
-      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
-           .to(order_t.device(), keep.scalar_type())});
+  return order_t.masked_select(keep);
 }
 
 } // namespace
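
The new return path relies on masked_select, which preserves element order: since order_t holds box indices sorted by descending score, the result remains score-sorted, matching the old narrow/index path. A small sketch of the semantics (illustrative values, not from this commit):

import torch

# order_t: box indices sorted by descending score.
# keep: boolean survivor mask as produced by gather_keep_from_mask.
order_t = torch.tensor([7, 2, 5, 0])
keep = torch.tensor([True, False, True, False])
print(order_t.masked_select(keep))  # tensor([7, 5]) -- still score-sorted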
