@@ -77,10 +77,12 @@ __global__ void nms_kernel_impl(
7777 }
7878}
7979
80- __global__ static void gather_keep_from_mask (bool *keep,
81- const unsigned long long *dev_mask,
82- const int n_boxes) {
83- // Taken and adapted from mmcv https://github.com/open-mmlab/mmcv/blob/03ce9208d18c0a63d7ffa087ea1c2f5661f2441a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh#L76
80+ __global__ static void gather_keep_from_mask (
81+ bool * keep,
82+ const unsigned long long * dev_mask,
83+ const int n_boxes) {
84+ // Taken and adapted from mmcv
85+ // https://github.com/open-mmlab/mmcv/blob/03ce9208d18c0a63d7ffa087ea1c2f5661f2441a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh#L76
8486 const int col_blocks = ceil_div (n_boxes, threadsPerBlock);
8587 const int thread_id = threadIdx .x ;
8688
@@ -97,10 +99,11 @@ __global__ static void gather_keep_from_mask(bool *keep,
9799 auto removed_val = removed[nblock];
98100 __syncthreads ();
99101 const int i_offset = nblock * threadsPerBlock;
100- #pragma unroll
102+ #pragma unroll
101103 for (int inblock = 0 ; inblock < threadsPerBlock; inblock++) {
102104 const int i = i_offset + inblock;
103- if (i >= n_boxes) break ;
105+ if (i >= n_boxes)
106+ break ;
104107 // Select a candidate, check if it should kept.
105108 if (!(removed_val & (1ULL << inblock))) {
106109 if (thread_id == 0 ) {
@@ -109,7 +112,8 @@ __global__ static void gather_keep_from_mask(bool *keep,
109112 auto p = dev_mask + i * col_blocks;
110113 // Remove all bboxes which overlap the candidate.
111114 for (int j = thread_id; j < col_blocks; j += blockDim .x ) {
112- if (j >= nblock) removed[j] |= p[j];
115+ if (j >= nblock)
116+ removed[j] |= p[j];
113117 }
114118 __syncthreads ();
115119 removed_val = removed[nblock];
@@ -174,19 +178,21 @@ at::Tensor nms_kernel(
174178 (unsigned long long *)mask.data_ptr <int64_t >());
175179 });
176180
177- at::Tensor keep = at::zeros (
178- {dets_num},
179- dets.options ().dtype (at::kBool ).device (at::kCUDA )
180- );
181+ at::Tensor keep =
182+ at::zeros ({dets_num}, dets.options ().dtype (at::kBool ).device (at::kCUDA ));
181183
182184 // Unwrap the mask to fill keep with proper values
183- // Keeping the unwrap on device instead of applying iterative for loops on cpu
185+ // Keeping the unwrap on device instead of applying iterative for loops on cpu
184186 // prevents the device -> cpu -> device transfer that could be bottleneck for
185187 // large number of boxes.
186188 // See https://github.com/pytorch/vision/issues/8713 for more details.
187- gather_keep_from_mask<<<1 , min(col_blocks, threadsPerBlock),
188- col_blocks * sizeof (unsigned long long ), stream>>> (
189- keep.data_ptr <bool >(), (unsigned long long *)mask.data_ptr <int64_t >(),
189+ gather_keep_from_mask<<<
190+ 1 ,
191+ min (col_blocks, threadsPerBlock),
192+ col_blocks * sizeof(unsigned long long ),
193+ stream>>>(
194+ keep.data_ptr<bool >(),
195+ (unsigned long long *)mask.data_ptr<int64_t>(),
190196 dets_num);
191197
192198 AT_CUDA_CHECK (cudaGetLastError());