Skip to content

Commit 2f5740b

Browse files
authored
[NMS] Add NMS f32 cuda kernel. (#102)
1 parent 6c89595 commit 2f5740b

File tree

3 files changed

+230
-1
lines changed

3 files changed

+230
-1
lines changed

nms/README.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# NMS
2+
3+
## 0x00 说明
4+
5+
包含以下内容:
6+
7+
- [X] nms_kernel(CPU/GPU)
8+
- [X] PyTorch bindings
9+
10+
当前的 NMS CUDA 实现是最基础的版本,可参考 [torchvision 官方源码](https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cuda/nms_kernel.cu) 做进一步优化。
11+
12+
## 测试
13+
14+
```bash
15+
# 只测试 Ada 架构;若不指定,默认会编译所有架构(Volta, Ampere, Ada, Hopper, ...),耗时较长
16+
export TORCH_CUDA_ARCH_LIST=Ada
17+
python3 nms.py
18+
```
19+
20+
输出:
21+
22+
```bash
23+
-------------------------------------------------------------------------------------
24+
nboxes=1024
25+
out_nms: ['1021 ', '1022 ', '1023 '], len of keep: 950, time:0.26456594ms
26+
out_nms_th: ['1021 ', '1022 ', '1023 '], len of keep: 950, time:0.19218683ms
27+
-------------------------------------------------------------------------------------
28+
-------------------------------------------------------------------------------------
29+
nboxes=2048
30+
out_nms: ['2045 ', '2046 ', '2047 '], len of keep: 1838, time:0.47256470ms
31+
out_nms_th: ['2044 ', '2045 ', '2047 '], len of keep: 1838, time:0.39437532ms
32+
-------------------------------------------------------------------------------------
33+
-------------------------------------------------------------------------------------
34+
nboxes=4096
35+
out_nms: ['4092 ', '4093 ', '4095 '], len of keep: 3598, time:0.89909315ms
36+
out_nms_th: ['4093 ', '4094 ', '4095 '], len of keep: 3598, time:1.03515625ms
37+
-------------------------------------------------------------------------------------
38+
-------------------------------------------------------------------------------------
39+
nboxes=8192
40+
out_nms: ['8189 ', '8190 ', '8191 '], len of keep: 7023, time:1.49935722ms
41+
out_nms_th: ['8189 ', '8190 ', '8191 '], len of keep: 7023, time:3.39094877ms
42+
-------------------------------------------------------------------------------------
43+
```

nms/nms.cu

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,103 @@
1-
// TODO: CUDA NMS
1+
#include <stdio.h>
2+
#include <stdlib.h>
3+
#include <float.h>
4+
#include <vector>
5+
#include <algorithm>
6+
#include <cuda_runtime.h>
7+
#include <cuda_fp16.h>
8+
#include <torch/types.h>
9+
#include <torch/extension.h>
10+
11+
#define WARP_SIZE 32
12+
#define INT4(value) (reinterpret_cast<int4 *>(&(value))[0])
13+
#define FLOAT4(value) (reinterpret_cast<float4 *>(&(value))[0])
14+
15+
// Upper bound on the number of threads in the single working block of
// nms_kernel; also passed to __launch_bounds__ so the launch cannot fail for
// lack of registers at this block size.
constexpr int kNmsMaxThreadsPerBlock = 1024;

/**
 * Greedy non-maximum suppression over boxes that are ALREADY sorted by
 * descending score.
 *
 * Launch layout: 1-D thread block. Only block 0 does any work (other blocks
 * exit immediately), because the algorithm relies on __syncthreads() to order
 * the pivots and that barrier only spans one block. Any blockDim.x up to
 * kNmsMaxThreadsPerBlock is valid; no shared memory is used.
 *
 * @param boxes          num_boxes * 4 floats, read as (x1, y1, x2, y2) per box.
 * @param scores         Unused; kept so the kernel signature stays stable.
 * @param keep           Output, num_boxes ints. keep[j] == 1 if box j survives,
 *                       0 if it is suppressed. Fully initialized by the kernel.
 * @param num_boxes      Number of boxes.
 * @param iou_threshold  A box is suppressed when its IoU with an earlier kept
 *                       box is strictly greater than this value.
 */
__global__ void __launch_bounds__(kNmsMaxThreadsPerBlock)
nms_kernel(const float *boxes, const float *scores, int *keep, int num_boxes, float iou_threshold) {
    (void)scores;

    // Block-uniform exit, so the barriers below are reached by every thread
    // of the one block that continues.
    if (blockIdx.x != 0)
        return;

    const int tid = threadIdx.x;
    const int stride = blockDim.x;

    // Start with every box kept; the caller's buffer may be uninitialized.
    for (int j = tid; j < num_boxes; j += stride)
        keep[j] = 1;
    __syncthreads();

    for (int i = 0; i < num_boxes; ++i) {
        // keep[i] is final here: only pivots < i can clear it, and the barrier
        // at the end of the previous iteration made their writes visible.
        // Every thread reads the same value, so this branch is block-uniform.
        if (keep[i] != 0) {
            const float x1_i = boxes[i * 4 + 0];
            const float y1_i = boxes[i * 4 + 1];
            const float x2_i = boxes[i * 4 + 2];
            const float y2_i = boxes[i * 4 + 3];
            const float area_i = (x2_i - x1_i) * (y2_i - y1_i);

            // The threads split the lower-scored boxes j > i between them.
            for (int j = i + 1 + tid; j < num_boxes; j += stride) {
                if (keep[j] == 0)
                    continue;

                const float x1 = boxes[j * 4 + 0];
                const float y1 = boxes[j * 4 + 1];
                const float x2 = boxes[j * 4 + 2];
                const float y2 = boxes[j * 4 + 3];

                const float inter_x1 = max(x1, x1_i);
                const float inter_y1 = max(y1, y1_i);
                const float inter_x2 = min(x2, x2_i);
                const float inter_y2 = min(y2, y2_i);
                const float inter_w = max(0.0f, inter_x2 - inter_x1);
                const float inter_h = max(0.0f, inter_y2 - inter_y1);
                const float inter_area = inter_w * inter_h;

                const float area = (x2 - x1) * (y2 - y1);
                // Two zero-area boxes give 0/0 = NaN, which compares false
                // below, so such boxes are never suppressed.
                const float iou = inter_area / (area + area_i - inter_area);

                if (iou > iou_threshold)
                    keep[j] = 0;
            }
        }
        // Outside the branch: publish this pivot's suppressions before the
        // next pivot is examined.
        __syncthreads();
    }
}

// --------------------- PyTorch bindings for custom kernel -----------------------
// Turns a token into a string literal.
#define STRINGFY(str) #str
// Registers `func` on the pybind11 module `m` under its own name.
#define TORCH_BINDING_COMMON_EXTENSION(func) \
    m.def(STRINGFY(func), &func, STRINGFY(func));

// Throws std::runtime_error (after printing the tensor options) when tensor T
// does not have dtype th_type.
#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                       \
    if (((T).options().dtype() != (th_type))) {                    \
        std::cout << "Tensor Info:" << (T).options() << std::endl; \
        throw std::runtime_error("values must be " #th_type);      \
    }

/**
 * Non-maximum suppression for float32 boxes on a CUDA device.
 *
 * @param boxes          [N, 4] float32 CUDA tensor of (x1, y1, x2, y2) boxes.
 * @param scores         [N] float32 CUDA tensor on the same device as `boxes`.
 * @param iou_threshold  Boxes whose IoU with a higher-scored kept box exceeds
 *                       this value are discarded.
 * @return 1-D int32 CPU tensor holding the indices (into the ORIGINAL `boxes`)
 *         of the kept boxes, ordered by descending score.
 * @throws std::runtime_error on a dtype mismatch; c10::Error (TORCH_CHECK) on
 *         a device/shape mismatch or a failed kernel launch.
 *
 * NOTE(review): the kernel is launched on the legacy default stream, as in the
 * original code, and no device guard is set; this assumes the caller uses the
 * default stream on the current device.
 */
torch::Tensor nms(torch::Tensor boxes, torch::Tensor scores, float iou_threshold) {
    CHECK_TORCH_TENSOR_DTYPE(boxes, torch::kFloat32);
    CHECK_TORCH_TENSOR_DTYPE(scores, torch::kFloat32);
    TORCH_CHECK(boxes.is_cuda() && scores.is_cuda(), "boxes and scores must be CUDA tensors");
    TORCH_CHECK(boxes.device() == scores.device(), "boxes and scores must be on the same device");
    TORCH_CHECK(boxes.dim() == 2 && boxes.size(1) == 4, "boxes must have shape [N, 4]");
    TORCH_CHECK(scores.dim() == 1 && scores.size(0) == boxes.size(0),
                "scores must have shape [N] matching boxes");

    const int num_boxes = static_cast<int>(boxes.size(0));
    const auto index_options = torch::TensorOptions().dtype(torch::kInt32);

    // A zero-sized launch is an invalid configuration; nothing to do anyway.
    if (num_boxes == 0)
        return torch::empty({0}, index_options);

    // sort boxes by scores (descending, stable) so the kernel can sweep greedily
    auto order_t = std::get<1>(
        scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
    auto boxes_sorted = boxes.index_select(0, order_t).contiguous();
    auto keep = torch::empty({boxes.size(0)}, index_options.device(boxes.device()));

    // One block: the sweep is ordered with __syncthreads(), which only spans a
    // block. Round the thread count up to a whole warp, capped at the bound.
    const int threads = std::min(
        kNmsMaxThreadsPerBlock,
        ((num_boxes + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE);
    dim3 block(threads);
    dim3 grid(1);

    nms_kernel<<<grid, block>>>(
        boxes_sorted.data_ptr<float>(),
        scores.data_ptr<float>(),
        keep.data_ptr<int>(),
        num_boxes, iou_threshold);
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "nms_kernel launch failed: ", cudaGetErrorString(launch_status));

    // keep[] is indexed by rank in the sorted order; translate the surviving
    // ranks back to indices into the caller's `boxes`.
    auto kept_indices = order_t.masked_select(keep.to(torch::kBool));
    return kept_indices.to(torch::kInt32).to(torch::kCPU);
}
100+
101+
// Python entry point of the extension: exposes nms() as `<module>.nms`, with
// the function name doubling as its docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("nms", &nms, "nms");
}

nms/nms.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import torch
2+
import time
3+
from torch.utils.cpp_extension import load
4+
from typing import Optional
5+
from functools import partial
6+
from torchvision.ops import nms
7+
torch.set_grad_enabled(False)
8+
9+
# Load the CUDA kernel as a python module
10+
# Build nms.cu into an importable extension module at import time.
# `extra_cuda_cflags` and `extra_cflags` are the option lists handed to
# torch.utils.cpp_extension.load for the CUDA and C++ sources.
_NVCC_FLAGS = [
    "-O3",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "--expt-relaxed-constexpr",
    "--expt-extended-lambda",
    "--use_fast_math",
]
_CXX_FLAGS = ["-std=c++17"]

lib = load(
    name="nms_lib",
    sources=["nms.cu"],
    extra_cuda_cflags=_NVCC_FLAGS,
    extra_cflags=_CXX_FLAGS,
)
25+
26+
27+
def generate_random_data(Nboxes):
    """Create random, well-formed boxes and scores for benchmarking.

    Args:
        Nboxes: number of boxes to generate.

    Returns:
        A ``(boxes, scores)`` tuple. ``boxes`` has shape ``(Nboxes, 4)`` laid
        out as ``(x1, y1, x2, y2)`` with ``x1 <= x2`` and ``y1 <= y2``;
        ``scores`` has shape ``(Nboxes,)``. Both are uniform samples from
        ``torch.rand``.
    """
    raw = torch.rand(Nboxes, 4)
    # Order each coordinate pair with element-wise min/max. A tuple swap such
    # as `t[i, 0], t[i, 2] = t[i, 2], t[i, 0]` does NOT work on tensors: the
    # right-hand sides are views, so the second assignment reads the value
    # that was just overwritten and the box collapses to zero width/height.
    x1 = torch.minimum(raw[:, 0], raw[:, 2])
    x2 = torch.maximum(raw[:, 0], raw[:, 2])
    y1 = torch.minimum(raw[:, 1], raw[:, 3])
    y2 = torch.maximum(raw[:, 1], raw[:, 3])
    boxes = torch.stack((x1, y1, x2, y2), dim=1)
    scores = torch.rand(Nboxes)
    return boxes, scores
36+
37+
38+
def run_benchmark(
    perf_func: callable,
    scores: torch.Tensor,
    boxes: torch.Tensor,
    thresholds: float,
    tag: str,
    warmup: int = 10,
    iters: int = 100,
    show_all: bool = False,
):
    """Time ``perf_func`` and print a one-line summary of its last result.

    ``perf_func`` is invoked as ``perf_func(scores, boxes, thresholds)``, i.e.
    the second and third arguments of this function are forwarded positionally
    in that order.
    NOTE(review): the callers in this file pass the box tensor as ``scores``
    and the score tensor as ``boxes``, so the parameter names are swapped
    relative to what they hold; confirm before calling with keywords.

    Args:
        perf_func: callable to benchmark.
        scores: first tensor forwarded to ``perf_func``.
        boxes: second tensor forwarded to ``perf_func``.
        thresholds: third argument forwarded to ``perf_func``.
        tag: label printed in front of the summary line.
        warmup: number of untimed calls made before timing starts.
        iters: number of timed calls; must be positive.
        show_all: when true, also print the full output of the last call.

    Returns:
        ``(out, mean_time)``: the output of the last call and the mean time
        per timed call in milliseconds.

    Raises:
        ValueError: if ``iters`` is not positive.
    """
    if iters <= 0:
        raise ValueError("iters must be a positive integer")

    # warmup
    for _ in range(warmup):
        out = perf_func(scores, boxes, thresholds)
    torch.cuda.synchronize()

    # perf_counter is monotonic and high-resolution; time.time() can jump when
    # the system clock is adjusted.
    start = time.perf_counter()
    # iters
    for _ in range(iters):
        out = perf_func(scores, boxes, thresholds)
    # Wait for queued GPU work so the interval covers the actual execution.
    torch.cuda.synchronize()
    end = time.perf_counter()
    total_time = (end - start) * 1000  # ms
    mean_time = total_time / iters
    out_info = f"{tag}"
    # Report only the three largest values of the output plus its length.
    out_val = sorted(out.flatten().detach().cpu().numpy().tolist())
    len_val = len(out_val)
    out_val = out_val[-min(3, len_val) :]
    out_val = [f"{v:<5}" for v in out_val]
    print(f"{out_info:>14}: {out_val}, len of keep: {len_val}, time:{mean_time:.8f}ms")
    if show_all:
        print(out)
    return out, mean_time
70+
71+
72+
# Problem sizes to benchmark and the IoU threshold shared by all runs.
Nboxes = [1024, 2048, 4096, 8192]
thresholds = 0.5

# For each size, time the custom extension and then torchvision's nms on the
# same randomly generated data, framed by separator lines.
for nboxes in Nboxes:
    print("-" * 85)
    print(" " * 40 + f"nboxes={nboxes}")
    boxes, scores = generate_random_data(nboxes)
    boxes = boxes.cuda().float().contiguous()
    scores = scores.cuda().float().contiguous()
    for impl, label in ((lib.nms, "nms"), (nms, "nms_th")):
        run_benchmark(impl, boxes, scores, thresholds, label)
    print("-" * 85)

0 commit comments

Comments
 (0)