Skip to content

Commit f353f46

Browse files
author
Wish
committed
update
1 parent 40e3faf commit f353f46

23 files changed

+801
-271
lines changed

.vscode/c_cpp_properties.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"/data/sxai/lean/protobuf3.11.4/include/**",
88
"/data/sxai/lean/opencv4.2.0/include/opencv4/**",
99
"/data/sxai/lean/cuda10.2/include/**",
10-
"/data/sxai/lean/TensorRT-8.0.1.6/include/**",
10+
"/data/sxai/lean/TensorRT-8.0.1.6-cuda10.2-cudnn8.2/include/**",
1111
"/data/sxai/lean/cudnn7.6.5.32-cuda10.2/include/**",
1212
"/data/datav/newbb/lean/anaconda3/envs/torch1.8/include/python3.9/**"
1313
],
@@ -23,7 +23,7 @@
2323
"/data/sxai/lean/protobuf3.11.4/include/**",
2424
"/data/sxai/lean/opencv4.2.0/include/opencv4/**",
2525
"/data/sxai/lean/cuda10.2/include/**",
26-
"/data/sxai/lean/TensorRT-8.0.1.6/include/**",
26+
"/data/sxai/lean/TensorRT-8.0.1.6-cuda10.2-cudnn8.2/include/**",
2727
"/data/sxai/lean/cudnn7.6.5.32-cuda10.2/include/**",
2828
"/data/datav/newbb/lean/anaconda3/envs/torch1.8/include/python3.9/**"
2929
],

.vscode/launch.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"type": "cppdbg",
1010
"request": "launch",
1111
"program": "${workspaceFolder}/workspace/pro",
12-
"args": ["high_perf"],
12+
"args": ["yolo"],
1313
"stopAtEntry": false,
1414
"cwd": "${workspaceFolder}/workspace",
1515
"environment": [],

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ high_perf : workspace/pro
145145
lesson : workspace/pro
146146
@cd workspace && ./pro lesson
147147

148+
plugin : workspace/pro
149+
@cd workspace && ./pro plugin
150+
148151
pytorch : trtpyc
149152
@cd python && python test_torch.py
150153

code.tar.gz

1.48 MB
Binary file not shown.

src/application/app_arcface.cpp

Lines changed: 78 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "app_scrfd/scrfd.hpp"
66
#include "app_arcface/arcface.hpp"
77
#include "tools/deepsort.hpp"
8-
//#include "tools/zmq_remote_show.hpp"
8+
#include "tools/zmq_remote_show.hpp"
99

1010
using namespace std;
1111
using namespace cv;
@@ -19,6 +19,9 @@ static bool compile_models(){
1919
TRT::set_device(0);
2020
string model_file;
2121

22+
if(!compile_retinaface(640, 480, model_file))
23+
return false;
24+
2225
if(!compile_scrfd(640, 480, model_file))
2326
return false;
2427

@@ -181,7 +184,7 @@ int app_arcface_video(){
181184
//auto remote_show = create_zmq_remote_show();
182185
INFO("Use tools/show.py to remote show");
183186

184-
VideoCapture cap("exp/WIN_20210425_14_23_24_Pro.mp4");
187+
VideoCapture cap("exp/face_tracker.mp4");
185188
Mat image;
186189
while(cap.read(image)){
187190
auto faces = detector->commit(image).get();
@@ -223,6 +226,46 @@ int app_arcface_video(){
223226
return 0;
224227
}
225228

229+
class MotionFilter{
230+
public:
231+
MotionFilter(){
232+
location_.left = location_.top = location_.right = location_.bottom = 0;
233+
}
234+
235+
void missed(){
236+
init_ = false;
237+
}
238+
239+
void update(const DeepSORT::Box& box){
240+
241+
const float a[] = {box.left, box.top, box.right, box.bottom};
242+
const float b[] = {location_.left, location_.top, location_.right, location_.bottom};
243+
244+
if(!init_){
245+
init_ = true;
246+
location_ = box;
247+
return;
248+
}
249+
250+
float v[4];
251+
for(int i = 0; i < 4; ++i)
252+
v[i] = a[i] * 0.6 + b[i] * 0.4;
253+
254+
location_.left = v[0];
255+
location_.top = v[1];
256+
location_.right = v[2];
257+
location_.bottom = v[3];
258+
}
259+
260+
DeepSORT::Box predict(){
261+
return location_;
262+
}
263+
264+
private:
265+
DeepSORT::Box location_;
266+
bool init_ = false;
267+
};
268+
226269
int app_arcface_tracker(){
227270

228271
TRT::set_device(0);
@@ -234,7 +277,7 @@ int app_arcface_tracker(){
234277
auto detector = Scrfd::create_infer("scrfd_2.5g_bnkps.640x480.FP32.trtmodel", 0, 0.6f);
235278
//auto detector = RetinaFace::create_infer("mb_retinaface.640x480.FP32.trtmodel", 0, 0.6f);
236279
auto arcface = Arcface::create_infer("arcface_iresnet50.fp32.trtmodel", 0);
237-
auto library = build_library(detector, arcface);
280+
//auto library = build_library(detector, arcface);
238281

239282
//tools/show.py connect to remote show
240283
//auto remote_show = create_zmq_remote_show();
@@ -252,10 +295,10 @@ int app_arcface_tracker(){
252295
});
253296

254297
auto tracker = DeepSORT::create_tracker(config);
255-
VideoCapture cap("exp/face_tracker1.mp4");
298+
VideoCapture cap("exp/face_tracker.mp4");
256299
Mat image;
257300

258-
VideoWriter writer("tracker.result.avi", cv::VideoWriter::fourcc('X', 'V', 'I', 'D'),
301+
VideoWriter writer("tracker.result.mp4", cv::VideoWriter::fourcc('a', 'v', 'c', '1'),
259302
cap.get(cv::CAP_PROP_FPS),
260303
Size(cap.get(cv::CAP_PROP_FRAME_WIDTH), cap.get(cv::CAP_PROP_FRAME_HEIGHT))
261304
);
@@ -264,51 +307,60 @@ int app_arcface_tracker(){
264307
return 0;
265308
}
266309

310+
unordered_map<int, MotionFilter> MotionFilter;
267311
while(cap.read(image)){
268312
auto faces = detector->commit(image).get();
269313
vector<string> names(faces.size());
270314
vector<DeepSORT::Box> boxes;
271315
for(int i = 0; i < faces.size(); ++i){
272316
auto& face = faces[i];
317+
if(max(face.width(), face.height()) < 30) continue;
318+
273319
auto crop = detector->crop_face_and_landmark(image, face);
274320
auto track_box = DeepSORT::convert_to_box(face);
275321

276322
Arcface::landmarks landmarks;
277323
memcpy(landmarks.points, get<1>(crop).landmark, sizeof(landmarks.points));
278324

279325
track_box.feature = arcface->commit(make_tuple(get<0>(crop), landmarks)).get();
280-
Mat scores = get<0>(library) * track_box.feature.t();
281-
float* pscore = scores.ptr<float>(0);
282-
int label = std::max_element(pscore, pscore + scores.rows) - pscore;
283-
float match_score = max(0.0f, pscore[label]);
326+
// Mat scores = get<0>(library) * track_box.feature.t();
327+
// float* pscore = scores.ptr<float>(0);
328+
// int label = std::max_element(pscore, pscore + scores.rows) - pscore;
329+
// float match_score = max(0.0f, pscore[label]);
284330
boxes.emplace_back(std::move(track_box));
285331

286-
if(match_score > 0.3f){
287-
names[i] = iLogger::format("%s[%.3f]", get<1>(library)[label].c_str(), match_score);
288-
}
332+
// if(match_score > 0.3f){
333+
// names[i] = iLogger::format("%s[%.3f]", get<1>(library)[label].c_str(), match_score);
334+
// }
289335
}
290336
tracker->update(boxes);
291337

292338
auto final_objects = tracker->get_objects();
293339
for(int i = 0; i < final_objects.size(); ++i){
294340
auto& person = final_objects[i];
295-
if(person->time_since_update() == 0 && person->state() == DeepSORT::State::Confirmed){
296-
Rect box = DeepSORT::convert_box_to_rect(person->last_position());
297-
298-
299-
// auto line = person->trace_line();
300-
// for(int j = 0; j < (int)line.size() - 1; ++j){
301-
// auto& p = line[j];
302-
// auto& np = line[j + 1];
303-
// cv::line(image, p, np, Scalar(255, 128, 60), 2, 16);
304-
// }
341+
auto& filter = MotionFilter[person->id()];
305342

343+
if(person->time_since_update() == 0 && person->state() == DeepSORT::State::Confirmed){
306344
uint8_t r, g, b;
307345
std::tie(r, g, b) = iLogger::random_color(person->id());
308346

309-
rectangle(image, DeepSORT::convert_box_to_rect(person->predict_box()), Scalar(0, 255, 0), 2);
310-
rectangle(image, box, Scalar(b, g, r), 3);
311-
putText(image, iLogger::format("%d", person->id()), Point(box.x, box.y-10), 0, 1, Scalar(b, g, r), 2, 16);
347+
auto loaction = person->last_position();
348+
filter.update(loaction);
349+
loaction = filter.predict();
350+
351+
const int shift = 4, shv = 1 << shift;
352+
rectangle(image,
353+
Point(loaction.left * shv, loaction.top * shv),
354+
Point(loaction.right * shv, loaction.bottom * shv),
355+
Scalar(b, g, r), 3, 16, shift
356+
);
357+
358+
putText(image, iLogger::format("%d", person->id()),
359+
Point(loaction.left, loaction.top - 10),
360+
0, 2, Scalar(b, g, r), 3, 16
361+
);
362+
}else{
363+
filter.missed();
312364
}
313365
}
314366

@@ -323,7 +375,7 @@ int app_arcface_tracker(){
323375
// putText(image, names[i], cv::Point(face.left + 30, face.top - 10), 0, 1, color, 2, 16);
324376
// }
325377
//remote_show->post(image);
326-
//writer.write(image);
378+
writer.write(image);
327379
}
328380
INFO("Done");
329381
return 0;

src/application/app_lesson.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ void lesson_cache1frame(){
105105

106106
//////////////////基础耗时////////////////////////
107107
{
108-
cv::VideoCapture cap("exp/face_tracker1.mp4");
108+
cv::VideoCapture cap("exp/face_tracker.mp4");
109109
cv::Mat image;
110110
int iframe = 0;
111111
auto t0 = iLogger::timestamp_now_float();
@@ -122,7 +122,7 @@ void lesson_cache1frame(){
122122

123123
//////////////////传统做法////////////////////////
124124
{
125-
cv::VideoCapture cap("exp/face_tracker1.mp4");
125+
cv::VideoCapture cap("exp/face_tracker.mp4");
126126
cv::Mat image;
127127
int iframe = 0;
128128
auto t0 = iLogger::timestamp_now_float();
@@ -147,7 +147,7 @@ void lesson_cache1frame(){
147147

148148
//////////////////优化做法////////////////////////
149149
{
150-
cv::VideoCapture cap("exp/face_tracker1.mp4");
150+
cv::VideoCapture cap("exp/face_tracker.mp4");
151151
shared_future<Yolo::ObjectBoxArray> prev_future;
152152
cv::Mat image;
153153
cv::Mat prev_image;

src/application/app_plugin.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
2+
#include <builder/trt_builder.hpp>
3+
#include <infer/trt_infer.hpp>
4+
#include <common/ilogger.hpp>
5+
#include "app_yolo/yolo.hpp"
6+
7+
using namespace std;
8+
9+
static void test_hswish(TRT::Mode mode){
10+
11+
// The plugin.onnx can be generated by the following code
12+
// cd workspace
13+
// python test_plugin.py
14+
iLogger::set_log_level(iLogger::LogLevel::Verbose);
15+
TRT::set_device(0);
16+
17+
auto mode_name = TRT::mode_string(mode);
18+
auto engine_name = iLogger::format("hswish.plugin.%s.trtmodel", mode_name);
19+
TRT::compile(
20+
mode, 3, "hswish.plugin.onnx", engine_name, {}
21+
);
22+
23+
auto engine = TRT::load_infer(engine_name);
24+
engine->print();
25+
26+
auto input0 = engine->input(0);
27+
auto input1 = engine->input(1);
28+
auto output = engine->output(0);
29+
30+
INFO("offset %d", output->offset(1, 0));
31+
INFO("input0: %s", input0->shape_string());
32+
INFO("input1: %s", input1->shape_string());
33+
INFO("output: %s", output->shape_string());
34+
35+
float input0_val = 0.8;
36+
float input1_val = 2;
37+
input0->set_to(input0_val);
38+
input1->set_to(input1_val);
39+
40+
auto hswish = [](float x){float a = x + 3; a=a<0?0:(a>=6?6:a); return x * a / 6;};
41+
auto sigmoid = [](float x){return 1 / (1 + exp(-x));};
42+
auto relu = [](float x){return max(0.0f, x);};
43+
float output_real = relu(hswish(input0_val) * input1_val);
44+
engine->forward(true);
45+
46+
INFO("output %f, output_real = %f", output->at<float>(0, 0), output_real);
47+
}
48+
49+
static void test_dcnv2(TRT::Mode mode){
50+
51+
// The plugin.onnx can be generated by the following code
52+
// cd workspace
53+
// python test_plugin.py
54+
iLogger::set_log_level(iLogger::LogLevel::Verbose);
55+
TRT::set_device(0);
56+
57+
auto mode_name = TRT::mode_string(mode);
58+
auto engine_name = iLogger::format("dcnv2.plugin.%s.trtmodel", mode_name);
59+
TRT::compile(
60+
mode, 1, "dcnv2.plugin.onnx", engine_name, {}
61+
);
62+
63+
auto engine = TRT::load_infer(engine_name);
64+
engine->print();
65+
66+
auto input0 = engine->input(0);
67+
auto input1 = engine->input(1);
68+
auto output = engine->output(0);
69+
70+
INFO("input0: %s", input0->shape_string());
71+
INFO("input1: %s", input1->shape_string());
72+
INFO("output: %s", output->shape_string());
73+
74+
float input0_val = 1;
75+
float input1_val = 1;
76+
input0->set_to(input0_val);
77+
input1->set_to(input1_val);
78+
engine->forward(true);
79+
80+
for(int i = 0; i < output->count(); ++i)
81+
INFO("output[%d] = %f", i, output->cpu<float>()[i]);
82+
}
83+
84+
int app_plugin(){
85+
86+
//test_hswish(TRT::Mode::FP32);
87+
test_dcnv2(TRT::Mode::FP32);
88+
//stest_plugin(TRT::Mode::FP16);
89+
return 0;
90+
}

0 commit comments

Comments
 (0)