 #include "app_scrfd/scrfd.hpp"
 #include "app_arcface/arcface.hpp"
 #include "tools/deepsort.hpp"
-// #include "tools/zmq_remote_show.hpp"
+#include "tools/zmq_remote_show.hpp"
 
 using namespace std;
 using namespace cv;
@@ -19,6 +19,9 @@ static bool compile_models(){
     TRT::set_device(0);
     string model_file;
 
+    if(!compile_retinaface(640, 480, model_file))
+        return false;
+
     if(!compile_scrfd(640, 480, model_file))
         return false;
 
@@ -181,7 +184,7 @@ int app_arcface_video(){
     // auto remote_show = create_zmq_remote_show();
     INFO("Use tools/show.py to remote show");
 
-    VideoCapture cap("exp/WIN_20210425_14_23_24_Pro.mp4");
+    VideoCapture cap("exp/face_tracker.mp4");
     Mat image;
     while(cap.read(image)){
         auto faces = detector->commit(image).get();
@@ -223,6 +226,46 @@ int app_arcface_video(){
     return 0;
 }
 
+class MotionFilter{
+public:
+    MotionFilter(){
+        location_.left = location_.top = location_.right = location_.bottom = 0;
+    }
+
+    void missed(){
+        init_ = false;
+    }
+
+    void update(const DeepSORT::Box& box){
+
+        const float a[] = {box.left, box.top, box.right, box.bottom};
+        const float b[] = {location_.left, location_.top, location_.right, location_.bottom};
+
+        if(!init_){
+            init_ = true;
+            location_ = box;
+            return;
+        }
+
+        float v[4];
+        for(int i = 0; i < 4; ++i)
+            v[i] = a[i] * 0.6 + b[i] * 0.4;
+
+        location_.left   = v[0];
+        location_.top    = v[1];
+        location_.right  = v[2];
+        location_.bottom = v[3];
+    }
+
+    DeepSORT::Box predict(){
+        return location_;
+    }
+
+private:
+    DeepSORT::Box location_;
+    bool init_ = false;
+};
+
 int app_arcface_tracker(){
 
     TRT::set_device(0);
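The MotionFilter added in this hunk is a per-track exponential moving average over the box edges: each update blends 0.6 of the new detection with 0.4 of the previous estimate, and missed() clears the state so a returning track re-seeds the filter. A minimal standalone sketch of the same smoothing, assuming only the four box fields (Box and EmaBoxFilter below are illustrative stand-ins, not part of the commit):

// Stand-in for DeepSORT::Box, reduced to the fields the filter touches.
struct Box{ float left, top, right, bottom; };

struct EmaBoxFilter{
    Box state{};
    bool init = false;

    // Forget the state so the next detection re-seeds the filter.
    void missed(){ init = false; }

    // Blend the new detection into the running estimate; alpha is the weight of the new box.
    Box update(const Box& det, float alpha = 0.6f){
        if(!init){ state = det; init = true; return state; }
        state.left   = alpha * det.left   + (1 - alpha) * state.left;
        state.top    = alpha * det.top    + (1 - alpha) * state.top;
        state.right  = alpha * det.right  + (1 - alpha) * state.right;
        state.bottom = alpha * det.bottom + (1 - alpha) * state.bottom;
        return state;
    }
};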
@@ -234,7 +277,7 @@ int app_arcface_tracker(){
     auto detector = Scrfd::create_infer("scrfd_2.5g_bnkps.640x480.FP32.trtmodel", 0, 0.6f);
     // auto detector = RetinaFace::create_infer("mb_retinaface.640x480.FP32.trtmodel", 0, 0.6f);
     auto arcface = Arcface::create_infer("arcface_iresnet50.fp32.trtmodel", 0);
-    auto library = build_library(detector, arcface);
+    // auto library = build_library(detector, arcface);
 
     // tools/show.py connect to remote show
     // auto remote_show = create_zmq_remote_show();
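With build_library() commented out above, the tracker demo skips face recognition; the matching step it used to feed (commented out further down in this commit) is a cosine-similarity lookup against the enrolled features. A sketch of that lookup in isolation, assuming the library rows and the query feature are L2-normalized float matrices (the helper name best_match is ours, not from the repo):

#include <opencv2/opencv.hpp>
#include <algorithm>

// Cosine similarity reduces to a matrix-vector product when rows are L2-normalized.
static int best_match(const cv::Mat& library_features, const cv::Mat& feature, float& score){
    cv::Mat scores = library_features * feature.t();   // N x 1 similarities
    float* p = scores.ptr<float>(0);
    int label = std::max_element(p, p + scores.rows) - p;
    score = p[label];
    return label;
}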
@@ -252,10 +295,10 @@ int app_arcface_tracker(){
     });
 
     auto tracker = DeepSORT::create_tracker(config);
-    VideoCapture cap("exp/face_tracker1.mp4");
+    VideoCapture cap("exp/face_tracker.mp4");
     Mat image;
 
-    VideoWriter writer("tracker.result.avi", cv::VideoWriter::fourcc('X', 'V', 'I', 'D'),
+    VideoWriter writer("tracker.result.mp4", cv::VideoWriter::fourcc('a', 'v', 'c', '1'),
         cap.get(cv::CAP_PROP_FPS),
         Size(cap.get(cv::CAP_PROP_FRAME_WIDTH), cap.get(cv::CAP_PROP_FRAME_HEIGHT))
     );
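The output switches from XVID in an .avi container to the H.264 fourcc ('a','v','c','1') in .mp4. Whether that encoder is available depends on how OpenCV was built (FFmpeg with an H.264 encoder), so a hedged sketch with a fallback to 'mp4v' (the helper open_writer is illustrative, not part of the commit):

#include <opencv2/opencv.hpp>
#include <string>

// Try H.264 first; fall back to MPEG-4 Part 2 if the build has no H.264 encoder.
static cv::VideoWriter open_writer(const std::string& path, double fps, cv::Size size){
    cv::VideoWriter w(path, cv::VideoWriter::fourcc('a', 'v', 'c', '1'), fps, size);
    if(!w.isOpened())
        w.open(path, cv::VideoWriter::fourcc('m', 'p', '4', 'v'), fps, size);
    return w;
}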
@@ -264,51 +307,60 @@ int app_arcface_tracker(){
         return 0;
     }
 
+    unordered_map<int, MotionFilter> MotionFilter;
     while(cap.read(image)){
         auto faces = detector->commit(image).get();
         vector<string> names(faces.size());
         vector<DeepSORT::Box> boxes;
         for(int i = 0; i < faces.size(); ++i){
             auto& face = faces[i];
+            if(max(face.width(), face.height()) < 30) continue;
+
             auto crop = detector->crop_face_and_landmark(image, face);
             auto track_box = DeepSORT::convert_to_box(face);
 
             Arcface::landmarks landmarks;
             memcpy(landmarks.points, get<1>(crop).landmark, sizeof(landmarks.points));
 
             track_box.feature = arcface->commit(make_tuple(get<0>(crop), landmarks)).get();
-            Mat scores = get<0>(library) * track_box.feature.t();
-            float* pscore = scores.ptr<float>(0);
-            int label = std::max_element(pscore, pscore + scores.rows) - pscore;
-            float match_score = max(0.0f, pscore[label]);
+            // Mat scores = get<0>(library) * track_box.feature.t();
+            // float* pscore = scores.ptr<float>(0);
+            // int label = std::max_element(pscore, pscore + scores.rows) - pscore;
+            // float match_score = max(0.0f, pscore[label]);
             boxes.emplace_back(std::move(track_box));
 
-            if(match_score > 0.3f){
-                names[i] = iLogger::format("%s[%.3f]", get<1>(library)[label].c_str(), match_score);
-            }
+            // if(match_score > 0.3f){
+            //     names[i] = iLogger::format("%s[%.3f]", get<1>(library)[label].c_str(), match_score);
+            // }
         }
         tracker->update(boxes);
 
         auto final_objects = tracker->get_objects();
         for(int i = 0; i < final_objects.size(); ++i){
             auto& person = final_objects[i];
-            if(person->time_since_update() == 0 && person->state() == DeepSORT::State::Confirmed){
-                Rect box = DeepSORT::convert_box_to_rect(person->last_position());
-
-
-                // auto line = person->trace_line();
-                // for(int j = 0; j < (int)line.size() - 1; ++j){
-                //     auto& p = line[j];
-                //     auto& np = line[j + 1];
-                //     cv::line(image, p, np, Scalar(255, 128, 60), 2, 16);
-                // }
+            auto& filter = MotionFilter[person->id()];
 
+            if(person->time_since_update() == 0 && person->state() == DeepSORT::State::Confirmed){
                 uint8_t r, g, b;
                 std::tie(r, g, b) = iLogger::random_color(person->id());
 
-                rectangle(image, DeepSORT::convert_box_to_rect(person->predict_box()), Scalar(0, 255, 0), 2);
-                rectangle(image, box, Scalar(b, g, r), 3);
-                putText(image, iLogger::format("%d", person->id()), Point(box.x, box.y - 10), 0, 1, Scalar(b, g, r), 2, 16);
+                auto loaction = person->last_position();
+                filter.update(loaction);
+                loaction = filter.predict();
+
+                const int shift = 4, shv = 1 << shift;
+                rectangle(image,
+                    Point(loaction.left * shv, loaction.top * shv),
+                    Point(loaction.right * shv, loaction.bottom * shv),
+                    Scalar(b, g, r), 3, 16, shift
+                );
+
+                putText(image, iLogger::format("%d", person->id()),
+                    Point(loaction.left, loaction.top - 10),
+                    0, 2, Scalar(b, g, r), 3, 16
+                );
+            }else{
+                filter.missed();
             }
         }
 
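The new drawing code relies on cv::rectangle's shift parameter: with shift = 4 the point coordinates carry 4 fractional bits (scaled by 1 << 4), so the fractional edges produced by the smoothing filter render at sub-pixel precision instead of snapping to whole pixels. The same idea as a small standalone helper (draw_subpixel_box is illustrative, not part of the commit):

#include <opencv2/opencv.hpp>

// Draw a float-coordinate box; shift = 4 means coordinates are in 1/16-pixel units.
static void draw_subpixel_box(cv::Mat& image, float l, float t, float r, float b,
                              const cv::Scalar& color, int thickness = 3){
    const int shift = 4, scale = 1 << shift;
    cv::rectangle(image,
        cv::Point(cvRound(l * scale), cvRound(t * scale)),
        cv::Point(cvRound(r * scale), cvRound(b * scale)),
        color, thickness, cv::LINE_AA, shift);
}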
@@ -323,7 +375,7 @@ int app_arcface_tracker(){
         //     putText(image, names[i], cv::Point(face.left + 30, face.top - 10), 0, 1, color, 2, 16);
         // }
         // remote_show->post(image);
-        // writer.write(image);
+        writer.write(image);
     }
     INFO("Done");
     return 0;