@@ -75,6 +75,63 @@ __global__ void nms_kernel_impl(
   }
 }
 
+__global__ static void gather_keep_from_mask(
+    bool* keep,
+    const unsigned long long* dev_mask,
+    const int n_boxes) {
+  // Taken and adapted from mmcv
+  // https://github.com/open-mmlab/mmcv/blob/03ce9208d18c0a63d7ffa087ea1c2f5661f2441a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh#L76
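+  //
+  // A single thread block scans the boxes in descending score order; for
+  // each kept candidate the threads cooperatively OR its overlap bits into
+  // the shared `removed` bitmap, so suppression happens entirely on device.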
+  const int col_blocks = ceil_div(n_boxes, threadsPerBlock);
+  const int thread_id = threadIdx.x;
+
+  // Mark the bboxes which have been removed.
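+  // Bit `t` of removed[b] marks box `b * threadsPerBlock + t` as suppressed
+  // by an earlier, higher-scoring box.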
+  extern __shared__ unsigned long long removed[];
+
+  // Initialize removed.
+  for (int i = thread_id; i < col_blocks; i += blockDim.x) {
+    removed[i] = 0;
+  }
+  __syncthreads();
+
+  for (int nblock = 0; nblock < col_blocks; nblock++) {
+    auto removed_val = removed[nblock];
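+    // Make sure every thread has read removed[nblock] before it is updated.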
+    __syncthreads();
+    const int i_offset = nblock * threadsPerBlock;
+#pragma unroll
+    for (int inblock = 0; inblock < threadsPerBlock; inblock++) {
+      const int i = i_offset + inblock;
+      if (i >= n_boxes)
+        break;
+      // Select a candidate and check whether it should be kept.
+      if (!(removed_val & (1ULL << inblock))) {
+        if (thread_id == 0) {
+          keep[i] = true;
+        }
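+        // Row i of dev_mask holds the overlap bits of box i against every
+        // column block, as computed by nms_kernel_impl above.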
+        auto p = dev_mask + i * col_blocks;
+        // Remove all bboxes which overlap the candidate.
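+        // Words before nblock are already final and can be skipped.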
+        for (int j = thread_id; j < col_blocks; j += blockDim.x) {
+          if (j >= nblock)
+            removed[j] |= p[j];
+        }
+        __syncthreads();
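+        // Reload this block's bits so later candidates observe the boxes
+        // suppressed by the current one.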
+        removed_val = removed[nblock];
+      }
+    }
+  }
+}
+
 at::Tensor nms_kernel(
     const at::Tensor& dets,
     const at::Tensor& scores,
@@ -131,35 +176,29 @@ at::Tensor nms_kernel(
             (unsigned long long*)mask.data_ptr<int64_t>());
       });
 
-  at::Tensor mask_cpu = mask.to(at::kCPU);
-  unsigned long long* mask_host =
-      (unsigned long long*)mask_cpu.data_ptr<int64_t>();
-
-  std::vector<unsigned long long> remv(col_blocks);
-  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
-
   at::Tensor keep =
-      at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
-  int64_t* keep_out = keep.data_ptr<int64_t>();
-
-  int num_to_keep = 0;
-  for (int i = 0; i < dets_num; i++) {
-    int nblock = i / threadsPerBlock;
-    int inblock = i % threadsPerBlock;
-
-    if (!(remv[nblock] & (1ULL << inblock))) {
-      keep_out[num_to_keep++] = i;
-      unsigned long long* p = mask_host + i * col_blocks;
-      for (int j = nblock; j < col_blocks; j++) {
-        remv[j] |= p[j];
-      }
-    }
-  }
+      at::zeros({dets_num}, dets.options().dtype(at::kBool).device(at::kCUDA));
+
+  // Unwrap the mask to fill keep with proper values.
+  // Keeping the unwrap on device instead of applying iterative for loops
+  // on the CPU avoids the device -> CPU -> device transfer that could be
+  // a bottleneck for a large number of boxes.
+  // See https://github.com/pytorch/vision/issues/8713 for more details.
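+  // Single-block launch: the dynamic shared memory holds one 64-bit
+  // suppression word per column block of boxes.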
+  gather_keep_from_mask<<<
+      1,
+      min(col_blocks, threadsPerBlock),
+      col_blocks * sizeof(unsigned long long),
+      stream>>>(
+      keep.data_ptr<bool>(),
+      (unsigned long long*)mask.data_ptr<int64_t>(),
+      dets_num);
 
   AT_CUDA_CHECK(cudaGetLastError());
-  return order_t.index(
-      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
-           .to(order_t.device(), keep.scalar_type())});
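+  // keep is laid out in descending-score order, so masked_select returns
+  // the kept indices already sorted by score.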
+  return order_t.masked_select(keep);
 }
 
 } // namespace