2424#include " FilterData.h"
2525#include " consts.h"
2626#include " obs-utils/obs-utils.h"
27- #include " edgeyolo /utils.hpp"
27+ #include " ort-model /utils.hpp"
2828#include " detect-filter-utils.h"
29+ #include " edgeyolo/edgeyolo_onnxruntime.hpp"
30+ #include " yunet/YuNet.h"
2931
3032#define EXTERNAL_MODEL_SIZE " !!!EXTERNAL_MODEL!!!"
33+ #define FACE_DETECT_MODEL_SIZE " !!!FACE_DETECT!!!"
3134
3235struct detect_filter : public filter_data {};
3336
@@ -325,6 +328,8 @@ obs_properties_t *detect_filter_properties(void *data)
325328 obs_property_list_add_string (model_size, obs_module_text (" SmallFast" ), " small" );
326329 obs_property_list_add_string (model_size, obs_module_text (" Medium" ), " medium" );
327330 obs_property_list_add_string (model_size, obs_module_text (" LargeSlow" ), " large" );
331+ obs_property_list_add_string (model_size, obs_module_text (" FaceDetect" ),
332+ FACE_DETECT_MODEL_SIZE);
328333 obs_property_list_add_string (model_size, obs_module_text (" ExternalModel" ),
329334 EXTERNAL_MODEL_SIZE);
330335
@@ -513,6 +518,9 @@ void detect_filter_update(void *data, obs_data_t *settings)
513518 } else if (newModelSize == " large" ) {
514519 modelFilepath_rawPtr =
515520 obs_module_file (" models/edgeyolo_tiny_lrelu_coco_736x1280.onnx" );
521+ } else if (newModelSize == FACE_DETECT_MODEL_SIZE) {
522+ modelFilepath_rawPtr =
523+ obs_module_file (" models/face_detection_yunet_2023mar.onnx" );
516524 } else if (newModelSize == EXTERNAL_MODEL_SIZE) {
517525 const char *external_model_file =
518526 obs_data_get_string (settings, " external_model_file" );
@@ -580,41 +588,53 @@ void detect_filter_update(void *data, obs_data_t *settings)
580588 obs_log (LOG_ERROR,
581589 " JSON file does not contain 'labels' field" );
582590 tf->isDisabled = true ;
583- tf->edgeyolo .reset ();
591+ tf->onnxruntimemodel .reset ();
584592 return ;
585593 }
586594 } else {
587595 obs_log (LOG_ERROR, " Failed to open JSON file: %s" ,
588596 labelsFilepath.c_str ());
589597 tf->isDisabled = true ;
590- tf->edgeyolo .reset ();
598+ tf->onnxruntimemodel .reset ();
591599 return ;
592600 }
601+ } else if (tf->modelSize == FACE_DETECT_MODEL_SIZE) {
602+ num_classes_ = 1 ;
603+ tf->classNames = yunet::FACE_CLASSES;
593604 }
594605
595606 // Load model
596607 try {
597- if (tf->edgeyolo ) {
598- tf->edgeyolo .reset ();
608+ if (tf->onnxruntimemodel ) {
609+ tf->onnxruntimemodel .reset ();
610+ }
611+ if (tf->modelSize == FACE_DETECT_MODEL_SIZE) {
612+ tf->onnxruntimemodel = std::make_unique<yunet::YuNetONNX>(
613+ tf->modelFilepath , tf->numThreads , 50 , tf->numThreads ,
614+ tf->useGPU , onnxruntime_device_id_,
615+ onnxruntime_use_parallel_, nms_th_, tf->conf_threshold );
616+ } else {
617+ tf->onnxruntimemodel =
618+ std::make_unique<edgeyolo_cpp::EdgeYOLOONNXRuntime>(
619+ tf->modelFilepath , tf->numThreads , num_classes_,
620+ tf->numThreads , tf->useGPU , onnxruntime_device_id_,
621+ onnxruntime_use_parallel_, nms_th_,
622+ tf->conf_threshold );
599623 }
600- tf->edgeyolo = std::make_unique<edgeyolo_cpp::EdgeYOLOONNXRuntime>(
601- tf->modelFilepath , tf->numThreads , tf->numThreads , tf->useGPU ,
602- onnxruntime_device_id_, onnxruntime_use_parallel_, nms_th_,
603- tf->conf_threshold , num_classes_);
604624 // clear error message
605625 obs_data_set_string (settings, " error" , " " );
606626 } catch (const std::exception &e) {
607627 obs_log (LOG_ERROR, " Failed to load model: %s" , e.what ());
608628 // disable filter
609629 tf->isDisabled = true ;
610- tf->edgeyolo .reset ();
630+ tf->onnxruntimemodel .reset ();
611631 return ;
612632 }
613633 }
614634
615635 // update threshold on edgeyolo
616- if (tf->edgeyolo ) {
617- tf->edgeyolo ->setBBoxConfThresh (tf->conf_threshold );
636+ if (tf->onnxruntimemodel ) {
637+ tf->onnxruntimemodel ->setBBoxConfThresh (tf->conf_threshold );
618638 }
619639
620640 if (reinitialize) {
@@ -746,7 +766,7 @@ void detect_filter_video_tick(void *data, float seconds)
746766
747767 struct detect_filter *tf = reinterpret_cast <detect_filter *>(data);
748768
749- if (tf->isDisabled || !tf->edgeyolo ) {
769+ if (tf->isDisabled || !tf->onnxruntimemodel ) {
750770 return ;
751771 }
752772
@@ -775,18 +795,16 @@ void detect_filter_video_tick(void *data, float seconds)
775795 cropRect = cv::Rect (tf->crop_left , tf->crop_top ,
776796 imageBGRA.cols - tf->crop_left - tf->crop_right ,
777797 imageBGRA.rows - tf->crop_top - tf->crop_bottom );
778- obs_log (LOG_INFO, " Crop: %d %d %d %d" , cropRect.x , cropRect.y , cropRect.width ,
779- cropRect.height );
780798 cv::cvtColor (imageBGRA (cropRect), inferenceFrame, cv::COLOR_BGRA2BGR);
781799 } else {
782800 cv::cvtColor (imageBGRA, inferenceFrame, cv::COLOR_BGRA2BGR);
783801 }
784802
785- std::vector<edgeyolo_cpp:: Object> objects;
803+ std::vector<Object> objects;
786804
787805 try {
788806 std::unique_lock<std::mutex> lock (tf->modelMutex );
789- objects = tf->edgeyolo ->inference (inferenceFrame);
807+ objects = tf->onnxruntimemodel ->inference (inferenceFrame);
790808 } catch (const Ort::Exception &e) {
791809 obs_log (LOG_ERROR, " ONNXRuntime Exception: %s" , e.what ());
792810 } catch (const std::exception &e) {
@@ -795,7 +813,7 @@ void detect_filter_video_tick(void *data, float seconds)
795813
796814 if (tf->crop_enabled ) {
797815 // translate the detected objects to the original frame
798- for (edgeyolo_cpp:: Object &obj : objects) {
816+ for (Object &obj : objects) {
799817 obj.rect .x += (float )cropRect.x ;
800818 obj.rect .y += (float )cropRect.y ;
801819 }
@@ -824,8 +842,8 @@ void detect_filter_video_tick(void *data, float seconds)
824842 }
825843
826844 if (tf->objectCategory != -1 ) {
827- std::vector<edgeyolo_cpp:: Object> filtered_objects;
828- for (const edgeyolo_cpp:: Object &obj : objects) {
845+ std::vector<Object> filtered_objects;
846+ for (const Object &obj : objects) {
829847 if (obj.label == tf->objectCategory ) {
830848 filtered_objects.push_back (obj);
831849 }
@@ -838,18 +856,17 @@ void detect_filter_video_tick(void *data, float seconds)
838856 }
839857
840858 if (!tf->showUnseenObjects ) {
841- objects.erase (std::remove_if (objects.begin (), objects.end (),
842- [](const edgeyolo_cpp::Object &obj) {
843- return obj.unseenFrames > 0 ;
844- }),
845- objects.end ());
859+ objects.erase (
860+ std::remove_if (objects.begin (), objects.end (),
861+ [](const Object &obj) { return obj.unseenFrames > 0 ; }),
862+ objects.end ());
846863 }
847864
848865 if (!tf->saveDetectionsPath .empty ()) {
849866 std::ofstream detectionsFile (tf->saveDetectionsPath );
850867 if (detectionsFile.is_open ()) {
851868 nlohmann::json j;
852- for (const edgeyolo_cpp:: Object &obj : objects) {
869+ for (const Object &obj : objects) {
853870 nlohmann::json obj_json;
854871 obj_json[" label" ] = obj.label ;
855872 obj_json[" confidence" ] = obj.prob ;
@@ -877,11 +894,11 @@ void detect_filter_video_tick(void *data, float seconds)
877894 drawDashedRectangle (frame, cropRect, cv::Scalar (0 , 255 , 0 ), 5 , 8 , 15 );
878895 }
879896 if (tf->preview && objects.size () > 0 ) {
880- edgeyolo_cpp::utils:: draw_objects (frame, objects, tf->classNames );
897+ draw_objects (frame, objects, tf->classNames );
881898 }
882899 if (tf->maskingEnabled ) {
883900 cv::Mat mask = cv::Mat::zeros (frame.size (), CV_8UC1);
884- for (const edgeyolo_cpp:: Object &obj : objects) {
901+ for (const Object &obj : objects) {
885902 cv::rectangle (mask, obj.rect , cv::Scalar (255 ), -1 );
886903 }
887904 std::lock_guard<std::mutex> lock (tf->outputLock );
@@ -906,7 +923,7 @@ void detect_filter_video_tick(void *data, float seconds)
906923 // get the bounding box of all objects
907924 if (objects.size () > 0 ) {
908925 boundingBox = objects[0 ].rect ;
909- for (const edgeyolo_cpp:: Object &obj : objects) {
926+ for (const Object &obj : objects) {
910927 boundingBox |= obj.rect ;
911928 }
912929 }
@@ -967,7 +984,7 @@ void detect_filter_video_render(void *data, gs_effect_t *_effect)
967984
968985 struct detect_filter *tf = reinterpret_cast <detect_filter *>(data);
969986
970- if (tf->isDisabled || !tf->edgeyolo ) {
987+ if (tf->isDisabled || !tf->onnxruntimemodel ) {
971988 if (tf->source ) {
972989 obs_source_skip_video_filter (tf->source );
973990 }
0 commit comments