Skip to content

Commit 5c2a5bd

Browse files
committed
Keep NMS index gathering on cuda device
1 parent fab1188 commit 5c2a5bd

File tree

1 file changed

+57
-27
lines changed

1 file changed

+57
-27
lines changed

torchvision/csrc/ops/cuda/nms_kernel.cu

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,48 @@ __global__ void nms_kernel_impl(
7777
}
7878
}
7979

80+
// Gather the final per-box "keep" flags from the pairwise suppression mask
// produced by nms_kernel_impl, entirely on the GPU (avoids the
// device -> host -> device round trip of the mask).
// Taken and adapted from mmcv https://github.com/open-mmlab/mmcv/blob/03ce9208d18c0a63d7ffa087ea1c2f5661f2441a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh#L76
//
// Launch contract (see the launch site in nms_kernel):
//   - exactly ONE block; blockDim.x <= threadsPerBlock
//   - dynamic shared memory: col_blocks * sizeof(unsigned long long)
//   - `keep` must be zero-initialized by the caller; this kernel only ever
//     sets entries to true.
__global__ static void gather_keep_from_mask(
    bool* __restrict__ keep,
    const unsigned long long* __restrict__ dev_mask,
    const int n_boxes) {
  const int col_blocks = ceil_div(n_boxes, threadsPerBlock);
  const int thread_id = threadIdx.x;

  // Bitset marking the bboxes which have been removed so far; one 64-bit
  // word per column block.
  extern __shared__ unsigned long long removed[];

  // Initialize removed (shared memory is uninitialized on entry).
  for (int i = thread_id; i < col_blocks; i += blockDim.x) {
    removed[i] = 0;
  }
  __syncthreads();

  for (int nblock = 0; nblock < col_blocks; nblock++) {
    auto removed_val = removed[nblock];
    __syncthreads();
    const int i_offset = nblock * threadsPerBlock;
#pragma unroll
    for (int inblock = 0; inblock < threadsPerBlock; inblock++) {
      const int i = i_offset + inblock;
      if (i >= n_boxes)
        break;
      // Select a candidate and check whether it should be kept.
      // NOTE: removed_val is identical across the block (every thread read
      // the same shared slot behind a barrier), so this branch -- and the
      // __syncthreads() inside it -- is block-uniform and safe.
      if (!(removed_val & (1ULL << inblock))) {
        if (thread_id == 0) {
          // Mark the output.
          keep[i] = true;
        }
        // Widen to 64-bit before multiplying: i * col_blocks ~ n_boxes^2 / 64
        // overflows 32-bit int for large numbers of boxes.
        auto p = dev_mask + static_cast<int64_t>(i) * col_blocks;
        // Remove all bboxes which overlap the candidate.
        for (int j = thread_id; j < col_blocks; j += blockDim.x) {
          if (j >= nblock)
            removed[j] |= p[j];
        }
        __syncthreads();
        removed_val = removed[nblock];
      }
    }
  }
}
121+
80122
at::Tensor nms_kernel(
81123
const at::Tensor& dets,
82124
const at::Tensor& scores,
@@ -133,35 +175,23 @@ at::Tensor nms_kernel(
133175
(unsigned long long*)mask.data_ptr<int64_t>());
134176
});
135177

136-
at::Tensor mask_cpu = mask.to(at::kCPU);
137-
unsigned long long* mask_host =
138-
(unsigned long long*)mask_cpu.data_ptr<int64_t>();
139-
140-
std::vector<unsigned long long> remv(col_blocks);
141-
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
142-
143-
at::Tensor keep =
144-
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
145-
int64_t* keep_out = keep.data_ptr<int64_t>();
146-
147-
int num_to_keep = 0;
148-
for (int i = 0; i < dets_num; i++) {
149-
int nblock = i / threadsPerBlock;
150-
int inblock = i % threadsPerBlock;
151-
152-
if (!(remv[nblock] & (1ULL << inblock))) {
153-
keep_out[num_to_keep++] = i;
154-
unsigned long long* p = mask_host + i * col_blocks;
155-
for (int j = nblock; j < col_blocks; j++) {
156-
remv[j] |= p[j];
157-
}
158-
}
159-
}
178+
at::Tensor keep = at::zeros(
179+
{dets_num},
180+
dets.options().dtype(at::kBool).device(at::kCUDA)
181+
);
182+
183+
// Unwrap the mask to fill keep with proper values
184+
// Keeping this unwrap on cuda instead of applying iterative for loops on cpu
185+
// prevents the device -> cpu -> device transfer that could be bottleneck for
186+
// large number of boxes.
187+
// See https://github.com/pytorch/vision/issues/8713 for more details
188+
gather_keep_from_mask<<<1, min(col_blocks, threadsPerBlock),
189+
col_blocks * sizeof(unsigned long long), stream>>>(
190+
keep.data_ptr<bool>(), (unsigned long long*)mask.data_ptr<int64_t>(),
191+
dets_num);
160192

161193
AT_CUDA_CHECK(cudaGetLastError());
162-
return order_t.index(
163-
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
164-
.to(order_t.device(), keep.scalar_type())});
194+
return order_t.masked_select(keep);
165195
}
166196

167197
} // namespace

0 commit comments

Comments
 (0)