
Commit c1a0ca2

NicolasHug authored and facebook-github-bot committed
[fbsync] Speed-up NMS by keeping index gathering on cuda device (#8766)
Reviewed By: scotts
Differential Revision: D77997069
fbshipit-source-id: aea325c752fac6be6d69d107c9506ff3adeefb38
Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 960d94a commit c1a0ca2
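
For context, the bitmask produced by nms_kernel_impl has one row per box (in descending score order) and ceil_div(n_boxes, 64) 64-bit words per row; bit j of word k in row i is set when box i overlaps box 64*k + j above the IoU threshold. The sketch below is an illustrative, self-contained host-side reference of the gathering step this commit moves onto the GPU; the helper name and standalone types are hypothetical and not part of the torchvision sources, but the loop mirrors the CPU code removed by the diff.

```cpp
#include <cstdint>
#include <vector>

// Illustrative host-side reference of the "gather keep flags from the
// suppression bitmask" step (hypothetical helper, not in the commit).
// mask holds n_boxes rows of col_blocks 64-bit words; bit j of word k in
// row i is set when box i suppresses box 64 * k + j.
std::vector<bool> gather_keep_on_host(
    const std::vector<unsigned long long>& mask,
    int n_boxes) {
  const int bits_per_word = 64;
  const int col_blocks = (n_boxes + bits_per_word - 1) / bits_per_word;
  std::vector<unsigned long long> removed(col_blocks, 0);
  std::vector<bool> keep(n_boxes, false);
  for (int i = 0; i < n_boxes; i++) {
    const int nblock = i / bits_per_word;
    const int inblock = i % bits_per_word;
    // Box i survives only if no higher-scoring kept box has suppressed it.
    if (!(removed[nblock] & (1ULL << inblock))) {
      keep[i] = true;
      // Mark every box that box i overlaps as removed.
      const unsigned long long* p = mask.data() + i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        removed[j] |= p[j];
      }
    }
  }
  return keep;
}
```

The new gather_keep_from_mask kernel computes exactly this, but spreads the removed[j] |= p[j] updates across the threads of a single CUDA block, so the bitmask never has to leave the device.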

File tree

1 file changed (+61 lines, −26 lines)

torchvision/csrc/ops/cuda/nms_kernel.cu

Lines changed: 61 additions & 26 deletions
```diff
@@ -75,6 +75,51 @@ __global__ void nms_kernel_impl(
   }
 }
 
+__global__ static void gather_keep_from_mask(
+    bool* keep,
+    const unsigned long long* dev_mask,
+    const int n_boxes) {
+  // Taken and adapted from mmcv
+  // https://github.com/open-mmlab/mmcv/blob/03ce9208d18c0a63d7ffa087ea1c2f5661f2441a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh#L76
+  const int col_blocks = ceil_div(n_boxes, threadsPerBlock);
+  const int thread_id = threadIdx.x;
+
+  // Mark the bboxes which have been removed.
+  extern __shared__ unsigned long long removed[];
+
+  // Initialize removed.
+  for (int i = thread_id; i < col_blocks; i += blockDim.x) {
+    removed[i] = 0;
+  }
+  __syncthreads();
+
+  for (int nblock = 0; nblock < col_blocks; nblock++) {
+    auto removed_val = removed[nblock];
+    __syncthreads();
+    const int i_offset = nblock * threadsPerBlock;
+#pragma unroll
+    for (int inblock = 0; inblock < threadsPerBlock; inblock++) {
+      const int i = i_offset + inblock;
+      if (i >= n_boxes)
+        break;
+      // Select a candidate and check if it should be kept.
+      if (!(removed_val & (1ULL << inblock))) {
+        if (thread_id == 0) {
+          keep[i] = true;
+        }
+        auto p = dev_mask + i * col_blocks;
+        // Remove all bboxes which overlap the candidate.
+        for (int j = thread_id; j < col_blocks; j += blockDim.x) {
+          if (j >= nblock)
+            removed[j] |= p[j];
+        }
+        __syncthreads();
+        removed_val = removed[nblock];
+      }
+    }
+  }
+}
+
 at::Tensor nms_kernel(
     const at::Tensor& dets,
     const at::Tensor& scores,
@@ -131,35 +176,25 @@ at::Tensor nms_kernel(
         (unsigned long long*)mask.data_ptr<int64_t>());
   });
 
-  at::Tensor mask_cpu = mask.to(at::kCPU);
-  unsigned long long* mask_host =
-      (unsigned long long*)mask_cpu.data_ptr<int64_t>();
-
-  std::vector<unsigned long long> remv(col_blocks);
-  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
-
   at::Tensor keep =
-      at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
-  int64_t* keep_out = keep.data_ptr<int64_t>();
-
-  int num_to_keep = 0;
-  for (int i = 0; i < dets_num; i++) {
-    int nblock = i / threadsPerBlock;
-    int inblock = i % threadsPerBlock;
-
-    if (!(remv[nblock] & (1ULL << inblock))) {
-      keep_out[num_to_keep++] = i;
-      unsigned long long* p = mask_host + i * col_blocks;
-      for (int j = nblock; j < col_blocks; j++) {
-        remv[j] |= p[j];
-      }
-    }
-  }
+      at::zeros({dets_num}, dets.options().dtype(at::kBool).device(at::kCUDA));
+
+  // Unwrap the mask to fill keep with proper values.
+  // Keeping the unwrap on device instead of applying iterative for loops on
+  // cpu prevents the device -> cpu -> device transfer that could be a
+  // bottleneck for a large number of boxes.
+  // See https://github.com/pytorch/vision/issues/8713 for more details.
+  gather_keep_from_mask<<<
+      1,
+      min(col_blocks, threadsPerBlock),
+      col_blocks * sizeof(unsigned long long),
+      stream>>>(
+      keep.data_ptr<bool>(),
+      (unsigned long long*)mask.data_ptr<int64_t>(),
+      dets_num);
 
   AT_CUDA_CHECK(cudaGetLastError());
-  return order_t.index(
-      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
-           .to(order_t.device(), keep.scalar_type())});
+  return order_t.masked_select(keep);
 }
 
 } // namespace
```
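
On the return-value change: order_t holds the original box indices sorted by descending score, and keep now lives on the GPU as one bool per sorted box, so a boolean masked_select replaces the old narrow + index on a CPU-built index list. The snippet below is a minimal standalone illustration, assuming a libtorch build (it is not part of the commit), of why the selected indices come back already sorted by score.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Original box indices sorted by descending score, as scores.sort() yields.
  auto order_t = torch::tensor({3, 0, 2, 1}, torch::kLong);
  // One flag per sorted box: true means the box survived NMS.
  auto keep = torch::tensor({1, 0, 1, 0}).to(torch::kBool);
  // Kept original indices, still in descending-score order: [3, 2].
  std::cout << order_t.masked_select(keep) << std::endl;
  return 0;
}
```

Because masked_select preserves the order of its input, the result matches what the previous keep.narrow(...) indexing returned, while staying entirely on the CUDA device.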
