1+ //
2+ // Created by wangzijian on 11/7/24.
3+ //
4+ #include " lite/lite.h"
5+
6+ #include " lite/trt/cv/trt_face_restoration_mt.h"
7+
8+ static void test_default ()
9+ {
10+ #ifdef ENABLE_ONNXRUNTIME
11+ std::string onnx_path = " /home/lite.ai.toolkit/examples/hub/onnx/cv/gfpgan_1.4.onnx" ;
12+ std::string test_img_path = " /home/lite.ai.toolkit/trt_result.jpg" ;
13+ std::string save_img_path = " /home/lite.ai.toolkit/trt_result_final.jpg" ;
14+
15+ // 1. Test Default Engine ONNXRuntime
16+ lite::cv::face::restoration::GFPGAN *face_restoration = new lite::cv::face::restoration::GFPGAN (onnx_path);
17+
18+ std::vector<cv::Point2f> face_landmark_5 = {
19+ cv::Point2f (569 .092041f , 398 .845886f ),
20+ cv::Point2f (701 .891724f , 399 .156677f ),
21+ cv::Point2f (634 .767212f , 482 .927216f ),
22+ cv::Point2f (584 .270996f , 543 .294617f ),
23+ cv::Point2f (684 .877991f , 543 .067078f )
24+ };
25+ cv::Mat img_bgr = cv::imread (test_img_path);
26+
27+ face_restoration->detect (img_bgr,face_landmark_5,save_img_path);
28+
29+
30+ std::cout<<" face restoration detect done!" <<std::endl;
31+
32+ delete face_restoration;
33+ #endif
34+ }
35+
36+
37+
38+
39+ static void test_tensorrt ()
40+ {
41+ #ifdef ENABLE_TENSORRT
42+ std::string engine_path = " /home/lite.ai.toolkit/examples/hub/trt/gfpgan_1.4_fp32.engine" ;
43+ std::string test_img_path = " /home/lite.ai.toolkit/trt_result.jpg" ;
44+ std::string save_img_path = " /home/lite.ai.toolkit/trt_facerestoration_mt_test111.jpg" ;
45+
46+ // 1. Test Default Engine TensorRT
47+ // lite::trt::cv::face::restoration::TRTGFPGAN *face_restoration_trt = new lite::trt::cv::face::restoration::TRTGFPGAN(engine_path);
48+
49+ const int num_threads = 4 ; // 使用4个线程
50+ auto face_restoration_trt = std::make_unique<trt_face_restoration_mt>(engine_path,4 );
51+
52+ // trt_face_restoration_mt *face_restoration_trt = new trt_face_restoration_mt(engine_path);
53+
54+
55+ // 2. 准备测试数据 - 这里假设我们要处理4张相同的图片作为示例
56+ std::vector<std::string> test_img_paths = {
57+ " /home/lite.ai.toolkit/trt_result.jpg" ,
58+ " /home/lite.ai.toolkit/trt_result_2.jpg" ,
59+ " /home/lite.ai.toolkit/trt_result_3.jpg" ,
60+ " /home/lite.ai.toolkit/trt_result_4.jpg"
61+ };
62+
63+ std::vector<std::string> save_img_paths = {
64+ " /home/lite.ai.toolkit/trt_facerestoration_mt_thread1.jpg" ,
65+ " /home/lite.ai.toolkit/trt_facerestoration_mt_thread2.jpg" ,
66+ " /home/lite.ai.toolkit/trt_facerestoration_mt_thread3.jpg" ,
67+ " /home/lite.ai.toolkit/trt_facerestoration_mt_thread4.jpg"
68+ };
69+
70+ std::vector<cv::Point2f> face_landmark_5 = {
71+ cv::Point2f (569 .092041f , 398 .845886f ),
72+ cv::Point2f (701 .891724f , 399 .156677f ),
73+ cv::Point2f (634 .767212f , 482 .927216f ),
74+ cv::Point2f (584 .270996f , 543 .294617f ),
75+ cv::Point2f (684 .877991f , 543 .067078f )
76+ };
77+ // cv::Mat img_bgr = cv::imread(test_img_path);
78+ //
79+ // face_restoration_trt->detect_async(img_bgr,face_landmark_5,save_img_path);
80+ //
81+ //
82+ // std::cout<<"face restoration detect done!"<<std::endl;
83+ //
84+ // delete face_restoration_trt;
85+ auto start_time = std::chrono::high_resolution_clock::now ();
86+
87+ for (size_t i=0 ; i < test_img_paths.size ();++i){
88+ cv::Mat img_bgr = cv::imread (test_img_paths[i]);
89+ if (img_bgr.empty ()) {
90+ std::cerr << " Failed to read image: " << test_img_paths[i] << std::endl;
91+ continue ;
92+ }
93+ // 异步提交任务
94+ face_restoration_trt->detect_async (img_bgr, face_landmark_5, save_img_paths[i]);
95+ std::cout << " Submitted task " << i + 1 << " for processing" << std::endl;
96+ }
97+
98+ // 6. 等待所有任务完成
99+ std::cout << " Waiting for all tasks to complete..." << std::endl;
100+ face_restoration_trt->wait_for_completion ();
101+
102+ // 7. 计算和输出总耗时
103+ auto end_time = std::chrono::high_resolution_clock::now ();
104+ auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
105+
106+ std::cout << " All tasks completed!" << std::endl;
107+ std::cout << " Total processing time: " << duration.count () << " ms" << std::endl;
108+ std::cout << " Average time per image: " << duration.count () / test_img_paths.size () << " ms" << std::endl;
109+
110+
111+ #endif
112+ }
113+
114+ int main (__unused int argc, __unused char *argv[])
115+ {
116+ // test_default();
117+ test_tensorrt ();
118+ return 0 ;
119+ }
0 commit comments