@@ -77,6 +77,48 @@ __global__ void nms_kernel_impl(
   }
 }
 
+__global__ static void gather_keep_from_mask(
+    bool* keep,
+    const unsigned long long* dev_mask,
+    const int n_boxes) {
+  // Taken and adapted from mmcv:
+  // https://github.com/open-mmlab/mmcv/blob/03ce9208d18c0a63d7ffa087ea1c2f5661f2441a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh#L76
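+  // Launched with a single block; its threads cooperatively replay the greedy
+  // NMS scan over the suppression mask, sharing per-block removal bitmasks
+  // through dynamic shared memory.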
+  const int col_blocks = ceil_div(n_boxes, threadsPerBlock);
+  const int thread_id = threadIdx.x;
+
+  // Mark the bboxes which have been removed.
+  extern __shared__ unsigned long long removed[];
+
+  // Initialize removed.
+  for (int i = thread_id; i < col_blocks; i += blockDim.x) {
+    removed[i] = 0;
+  }
+  __syncthreads();
+
+  for (int nblock = 0; nblock < col_blocks; nblock++) {
+    auto removed_val = removed[nblock];
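+    // Make sure every thread has read this block's removal word before any
+    // thread can update it in the suppression loop below.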
+    __syncthreads();
+    const int i_offset = nblock * threadsPerBlock;
+#pragma unroll
+    for (int inblock = 0; inblock < threadsPerBlock; inblock++) {
+      const int i = i_offset + inblock;
+      if (i >= n_boxes) break;
+      // Select a candidate and check whether it should be kept.
+      if (!(removed_val & (1ULL << inblock))) {
+        if (thread_id == 0) {
+          // Mark the output.
+          keep[i] = true;
+        }
+        auto p = dev_mask + i * col_blocks;
+        // Remove all bboxes which overlap the candidate.
+        for (int j = thread_id; j < col_blocks; j += blockDim.x) {
+          if (j >= nblock) removed[j] |= p[j];
+        }
+        __syncthreads();
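+        // Re-read this block's word: the candidate just kept may have
+        // suppressed later boxes within the same block.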
+        removed_val = removed[nblock];
+      }
+    }
+  }
+}
+
 at::Tensor nms_kernel(
     const at::Tensor& dets,
     const at::Tensor& scores,
@@ -133,35 +175,23 @@ at::Tensor nms_kernel(
             (unsigned long long*)mask.data_ptr<int64_t>());
       });
 
-  at::Tensor mask_cpu = mask.to(at::kCPU);
-  unsigned long long* mask_host =
-      (unsigned long long*)mask_cpu.data_ptr<int64_t>();
-
-  std::vector<unsigned long long> remv(col_blocks);
-  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
-
-  at::Tensor keep =
-      at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
-  int64_t* keep_out = keep.data_ptr<int64_t>();
-
-  int num_to_keep = 0;
-  for (int i = 0; i < dets_num; i++) {
-    int nblock = i / threadsPerBlock;
-    int inblock = i % threadsPerBlock;
-
-    if (!(remv[nblock] & (1ULL << inblock))) {
-      keep_out[num_to_keep++] = i;
-      unsigned long long* p = mask_host + i * col_blocks;
-      for (int j = nblock; j < col_blocks; j++) {
-        remv[j] |= p[j];
-      }
-    }
-  }
+  at::Tensor keep = at::zeros(
+      {dets_num}, dets.options().dtype(at::kBool).device(at::kCUDA));
+
+  // Unwrap the mask to fill keep with proper values.
+  // Keeping this unwrap on CUDA instead of running an iterative for loop on
+  // the CPU avoids the device -> cpu -> device transfer that could become a
+  // bottleneck for a large number of boxes.
+  // See https://github.com/pytorch/vision/issues/8713 for more details.
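+  // One block of min(col_blocks, threadsPerBlock) threads; the dynamic shared
+  // memory holds one removal word per column block.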
+  gather_keep_from_mask<<<1, min(col_blocks, threadsPerBlock),
+                          col_blocks * sizeof(unsigned long long), stream>>>(
+      keep.data_ptr<bool>(), (unsigned long long*)mask.data_ptr<int64_t>(),
+      dets_num);
 
   AT_CUDA_CHECK(cudaGetLastError());
-  return order_t.index(
-      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
-           .to(order_t.device(), keep.scalar_type())});
+  return order_t.masked_select(keep);
 }
 
 } // namespace