@@ -77,6 +77,51 @@ __global__ void nms_kernel_impl(
   }
 }
 
+__global__ static void gather_keep_from_mask(
+    bool* keep,
+    const unsigned long long* dev_mask,
+    const int n_boxes) {
+  // Taken and adapted from mmcv
+  // https://github.com/open-mmlab/mmcv/blob/03ce9208d18c0a63d7ffa087ea1c2f5661f2441a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh#L76
+  const int col_blocks = ceil_div(n_boxes, threadsPerBlock);
+  const int thread_id = threadIdx.x;
+
+  // Mark the bboxes which have been removed.
+  extern __shared__ unsigned long long removed[];
+
+  // Initialize removed.
+  for (int i = thread_id; i < col_blocks; i += blockDim.x) {
+    removed[i] = 0;
+  }
+  __syncthreads();
+
+  for (int nblock = 0; nblock < col_blocks; nblock++) {
+    auto removed_val = removed[nblock];
+    __syncthreads();
+    const int i_offset = nblock * threadsPerBlock;
+#pragma unroll
+    for (int inblock = 0; inblock < threadsPerBlock; inblock++) {
+      const int i = i_offset + inblock;
+      if (i >= n_boxes)
+        break;
+      // Select a candidate; check if it should be kept.
+      if (!(removed_val & (1ULL << inblock))) {
+        if (thread_id == 0) {
+          keep[i] = true;
+        }
+        auto p = dev_mask + i * col_blocks;
+        // Remove all bboxes which overlap the candidate.
+        for (int j = thread_id; j < col_blocks; j += blockDim.x) {
+          if (j >= nblock)
+            removed[j] |= p[j];
+        }
+        __syncthreads();
+        removed_val = removed[nblock];
+      }
+    }
+  }
+}
+
 at::Tensor nms_kernel(
     const at::Tensor& dets,
     const at::Tensor& scores,
@@ -133,35 +178,25 @@ at::Tensor nms_kernel(
             (unsigned long long*)mask.data_ptr<int64_t>());
       });
 
-  at::Tensor mask_cpu = mask.to(at::kCPU);
-  unsigned long long* mask_host =
-      (unsigned long long*)mask_cpu.data_ptr<int64_t>();
-
-  std::vector<unsigned long long> remv(col_blocks);
-  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
-
   at::Tensor keep =
-      at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
-  int64_t* keep_out = keep.data_ptr<int64_t>();
-
-  int num_to_keep = 0;
-  for (int i = 0; i < dets_num; i++) {
-    int nblock = i / threadsPerBlock;
-    int inblock = i % threadsPerBlock;
-
-    if (!(remv[nblock] & (1ULL << inblock))) {
-      keep_out[num_to_keep++] = i;
-      unsigned long long* p = mask_host + i * col_blocks;
-      for (int j = nblock; j < col_blocks; j++) {
-        remv[j] |= p[j];
-      }
-    }
-  }
+      at::zeros({dets_num}, dets.options().dtype(at::kBool).device(at::kCUDA));
+
+  // Unwrap the mask to fill keep with proper values.
+  // Keeping the unwrap on device instead of applying iterative for loops on
+  // cpu prevents the device -> cpu -> device transfer that could be a
+  // bottleneck for a large number of boxes.
+  // See https://github.com/pytorch/vision/issues/8713 for more details.
+  gather_keep_from_mask<<<
+      1,
+      min(col_blocks, threadsPerBlock),
+      col_blocks * sizeof(unsigned long long),
+      stream>>>(
+      keep.data_ptr<bool>(),
+      (unsigned long long*)mask.data_ptr<int64_t>(),
+      dets_num);
 
   AT_CUDA_CHECK(cudaGetLastError());
-  return order_t.index(
-      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
-           .to(order_t.device(), keep.scalar_type())});
+  return order_t.masked_select(keep);
 }
 
 } // namespace
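
For reference, here is a sequential host-side sketch of the semantics that both the deleted CPU loop and the new single-block `gather_keep_from_mask` kernel implement: walk the boxes in descending score order, keep a box unless an already-kept box has its overlap bit set, and when a box is kept, OR its row of the overlap mask into a running removal mask. This is an illustration only, not code from the patch; the function name is hypothetical, and it assumes `threadsPerBlock` is 64, matching the `unsigned long long` mask words implied by `1ULL << inblock` and `col_blocks * sizeof(unsigned long long)` above.

```cpp
#include <vector>

// Sequential reference for the mask unwrap (illustrative sketch only).
// keep: n_boxes booleans, pre-zeroed; mask: n_boxes x col_blocks words.
static void gather_keep_reference(
    bool* keep,
    const unsigned long long* mask,
    int n_boxes) {
  const int col_blocks = (n_boxes + 63) / 64;
  std::vector<unsigned long long> removed(col_blocks, 0);
  for (int i = 0; i < n_boxes; i++) {
    const int nblock = i / 64;
    const int inblock = i % 64;
    // Box i survives if no higher-scoring kept box marked it as overlapping.
    if (!(removed[nblock] & (1ULL << inblock))) {
      keep[i] = true;
      // Mark every box that box i suppresses; only word blocks at or after
      // nblock are consulted, matching the j >= nblock guard in the kernel.
      const unsigned long long* row = mask + i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        removed[j] |= row[j];
      }
    }
  }
}
```

The kernel parallelizes the removal-mask initialization and updates across the threads of a single block while scanning candidates sequentially, which is why the launch above uses one block, `min(col_blocks, threadsPerBlock)` threads, and `col_blocks * sizeof(unsigned long long)` bytes of dynamic shared memory for `removed[]`.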