#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <iostream>
#include <stdexcept>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <torch/types.h>
#include <torch/extension.h>

#define WARP_SIZE 32
// Vectorized load/store helpers (not used by the NMS kernel below).
#define INT4(value) (reinterpret_cast<int4 *>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4 *>(&(value))[0])
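// -----------------------------------------------------------------------------
// Naive parallel NMS: one thread per box. The host wrapper below passes boxes
// already sorted by score in descending order, so a lower index means a higher
// score. Each thread compares its box against all higher-scored boxes and
// clears its keep flag if any kept box overlaps it with IoU above the
// threshold. This is a simplified per-thread sketch, not the bitmask-based
// exact algorithm used by torchvision's CUDA NMS.
// -----------------------------------------------------------------------------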
__global__ void nms_kernel(const float *boxes, const float *scores,
                           int *keep, int num_boxes, float iou_threshold) {
  const int threadsPerBlock = blockDim.x;
  const int threadId = threadIdx.x;
  const int blockId = blockIdx.x;
  const int idx = blockId * threadsPerBlock + threadId;

  if (idx >= num_boxes)
    return;

  // Coordinates of the box handled by this thread ([x1, y1, x2, y2] layout).
  float x1 = boxes[idx * 4 + 0];
  float y1 = boxes[idx * 4 + 1];
  float x2 = boxes[idx * 4 + 2];
  float y2 = boxes[idx * 4 + 3];

  // Compare against every higher-scored box (lower index after sorting).
  for (int i = 0; i < idx; ++i) {
    // Skip boxes that have already been suppressed. keep[] is written
    // concurrently by other threads without synchronization, so a stale value
    // may be read here; the result approximates sequential greedy NMS.
    if (keep[i] == 0)
      continue;

    float x1_i = boxes[i * 4 + 0];
    float y1_i = boxes[i * 4 + 1];
    float x2_i = boxes[i * 4 + 2];
    float y2_i = boxes[i * 4 + 3];

    // Intersection-over-union between box idx and box i.
    float inter_x1 = fmaxf(x1, x1_i);
    float inter_y1 = fmaxf(y1, y1_i);
    float inter_x2 = fminf(x2, x2_i);
    float inter_y2 = fminf(y2, y2_i);
    float inter_w = fmaxf(0.0f, inter_x2 - inter_x1);
    float inter_h = fmaxf(0.0f, inter_y2 - inter_y1);
    float inter_area = inter_w * inter_h;

    float area = (x2 - x1) * (y2 - y1);
    float area_i = (x2_i - x1_i) * (y2_i - y1_i);
    float iou = inter_area / (area + area_i - inter_area);

    if (iou > iou_threshold) {
      keep[idx] = 0;  // suppressed by a higher-scored box
      return;
    }
  }
  keep[idx] = 1;  // survived all comparisons
}

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));

#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                     \
  if (((T).options().dtype() != (th_type))) {                    \
    std::cout << "Tensor Info:" << (T).options() << std::endl;   \
    throw std::runtime_error("values must be " #th_type);        \
  }
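
// Host-side wrapper: sorts the boxes by score, launches one thread per box and
// converts the resulting keep mask into indices into the original boxes tensor.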
torch::Tensor nms(torch::Tensor boxes, torch::Tensor scores, float iou_threshold) {
  CHECK_TORCH_TENSOR_DTYPE(boxes, torch::kFloat32);
  CHECK_TORCH_TENSOR_DTYPE(scores, torch::kFloat32);
  TORCH_CHECK(boxes.is_cuda() && scores.is_cuda(), "boxes and scores must be CUDA tensors");

  const int num_boxes = boxes.size(0);
  auto toption = torch::TensorOptions().dtype(torch::kInt32).device(boxes.device());
  // Every box starts out as "kept"; the kernel only clears entries.
  auto keep = torch::ones({num_boxes}, toption);
  dim3 block(WARP_SIZE);
  dim3 grid((num_boxes + WARP_SIZE - 1) / WARP_SIZE);

  // Sort boxes by score in descending order so that a lower index means a higher score.
  auto order_t = std::get<1>(
      scores.sort(/*stable=*/true, /*dim=*/0, /*descending=*/true));
  auto boxes_sorted = boxes.index_select(0, order_t).contiguous();

  nms_kernel<<<grid, block>>>(
      boxes_sorted.data_ptr<float>(),
      scores.data_ptr<float>(),  // unused inside the kernel; boxes are already sorted
      keep.data_ptr<int>(),
      num_boxes, iou_threshold);
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(err));
  }
  auto keep_cpu = keep.to(torch::kCPU);  // synchronizes with the kernel
  auto order_cpu = order_t.to(torch::kCPU);

  // Map positions in the sorted order back to indices into the original boxes.
  std::vector<int> keep_indices;
  auto keep_accessor = keep_cpu.accessor<int, 1>();
  auto order_accessor = order_cpu.accessor<int64_t, 1>();
  for (int i = 0; i < num_boxes; ++i) {
    if (keep_accessor[i] == 1) {
      keep_indices.push_back(static_cast<int>(order_accessor[i]));
    }
  }
  return torch::tensor(keep_indices, torch::TensorOptions().dtype(torch::kInt32));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(nms)
}
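
// ------------------------------ usage sketch --------------------------------
// A minimal, illustrative way to build and call this extension from Python via
// torch.utils.cpp_extension.load. The file name "nms.cu" and module name
// "nms_ext" are assumptions for the example, not fixed by this file:
//
//   import torch
//   from torch.utils.cpp_extension import load
//
//   nms_ext = load(name="nms_ext", sources=["nms.cu"])
//   boxes = torch.rand(128, 4, device="cuda")
//   boxes[:, 2:] += boxes[:, :2]            # ensure x2 >= x1 and y2 >= y1
//   scores = torch.rand(128, device="cuda")
//   keep = nms_ext.nms(boxes, scores, 0.5)  # indices of the boxes to keep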