
Commit 53cef58

clang-format

1 parent 2c6edf4


torchvision/csrc/ops/cuda/nms_kernel.cu

Lines changed: 21 additions & 15 deletions
@@ -77,10 +77,12 @@ __global__ void nms_kernel_impl(
   }
 }

-__global__ static void gather_keep_from_mask(bool *keep,
-                                             const unsigned long long *dev_mask,
-                                             const int n_boxes) {
-  // Taken and adapted from mmcv https://github.com/open-mmlab/mmcv/blob/03ce9208d18c0a63d7ffa087ea1c2f5661f2441a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh#L76
+__global__ static void gather_keep_from_mask(
+    bool* keep,
+    const unsigned long long* dev_mask,
+    const int n_boxes) {
+  // Taken and adapted from mmcv
+  // https://github.com/open-mmlab/mmcv/blob/03ce9208d18c0a63d7ffa087ea1c2f5661f2441a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh#L76
   const int col_blocks = ceil_div(n_boxes, threadsPerBlock);
   const int thread_id = threadIdx.x;

@@ -97,10 +99,11 @@ __global__ static void gather_keep_from_mask(bool *keep,
     auto removed_val = removed[nblock];
     __syncthreads();
     const int i_offset = nblock * threadsPerBlock;
-    #pragma unroll
+#pragma unroll
     for (int inblock = 0; inblock < threadsPerBlock; inblock++) {
       const int i = i_offset + inblock;
-      if (i >= n_boxes) break;
+      if (i >= n_boxes)
+        break;
       // Select a candidate, check if it should kept.
       if (!(removed_val & (1ULL << inblock))) {
         if (thread_id == 0) {
@@ -109,7 +112,8 @@ __global__ static void gather_keep_from_mask(bool *keep,
         auto p = dev_mask + i * col_blocks;
         // Remove all bboxes which overlap the candidate.
         for (int j = thread_id; j < col_blocks; j += blockDim.x) {
-          if (j >= nblock) removed[j] |= p[j];
+          if (j >= nblock)
+            removed[j] |= p[j];
         }
         __syncthreads();
         removed_val = removed[nblock];
@@ -174,19 +178,21 @@ at::Tensor nms_kernel(
             (unsigned long long*)mask.data_ptr<int64_t>());
       });

-  at::Tensor keep = at::zeros(
-      {dets_num},
-      dets.options().dtype(at::kBool).device(at::kCUDA)
-  );
+  at::Tensor keep =
+      at::zeros({dets_num}, dets.options().dtype(at::kBool).device(at::kCUDA));

   // Unwrap the mask to fill keep with proper values
-  // Keeping the unwrap on device instead of applying iterative for loops on cpu
+  // Keeping the unwrap on device instead of applying iterative for loops on cpu
   // prevents the device -> cpu -> device transfer that could be bottleneck for
   // large number of boxes.
   // See https://github.com/pytorch/vision/issues/8713 for more details.
-  gather_keep_from_mask<<<1, min(col_blocks, threadsPerBlock),
-                          col_blocks * sizeof(unsigned long long), stream>>>(
-      keep.data_ptr<bool>(), (unsigned long long*)mask.data_ptr<int64_t>(),
+  gather_keep_from_mask<<<
+      1,
+      min(col_blocks, threadsPerBlock),
+      col_blocks * sizeof(unsigned long long),
+      stream>>>(
+      keep.data_ptr<bool>(),
+      (unsigned long long*)mask.data_ptr<int64_t>(),
       dets_num);

   AT_CUDA_CHECK(cudaGetLastError());
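
For context on what the reformatted kernel does (this is not part of the commit): gather_keep_from_mask parallelizes the serial mask-unwrap step of CUDA NMS. nms_kernel_impl produces a pairwise suppression bitmask of n_boxes rows by col_blocks 64-bit words, where bit b of row i, word jb set means box i suppresses box jb * 64 + b (boxes are pre-sorted by score). Below is a minimal host-side sketch of the serial logic being parallelized, roughly the "iterative for loops on cpu" path the in-code comment says the device kernel avoids. The function name, the mask_host pointer, and the kThreadsPerBlock constant are illustrative assumptions, not torchvision API.

#include <cstddef>
#include <vector>

// Illustrative sketch only: unwrap_mask_on_cpu, mask_host, and
// kThreadsPerBlock are assumed names, not torchvision code.
constexpr int kThreadsPerBlock = 64; // bits per mask word

// mask_host: n_boxes rows of col_blocks 64-bit words, copied back from the
// GPU; bit b of row i, word jb set means box i suppresses box jb * 64 + b.
std::vector<bool> unwrap_mask_on_cpu(
    const unsigned long long* mask_host,
    int n_boxes) {
  const int col_blocks = (n_boxes + kThreadsPerBlock - 1) / kThreadsPerBlock;
  std::vector<unsigned long long> removed(col_blocks, 0);
  std::vector<bool> keep(n_boxes, false);
  // Boxes are pre-sorted by score, so a plain forward scan is greedy NMS.
  for (int i = 0; i < n_boxes; i++) {
    const int nblock = i / kThreadsPerBlock;
    const int inblock = i % kThreadsPerBlock;
    // Keep box i unless a previously kept box has suppressed it.
    if (!(removed[nblock] & (1ULL << inblock))) {
      keep[i] = true;
      // Fold box i's suppression row into the running bitmask.
      const unsigned long long* p = mask_host + (std::size_t)i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        removed[j] |= p[j];
      }
    }
  }
  return keep;
}

The device kernel in the diff mirrors this loop structure: removed lives in dynamic shared memory (hence the col_blocks * sizeof(unsigned long long) launch argument, per the mmcv kernel it was adapted from), a single block of min(col_blocks, threadsPerBlock) threads folds each kept box's row in parallel, and keep is written directly on device, so the bitmask never makes a device -> cpu -> device round trip.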
