Skip to content
This repository was archived by the owner on Dec 18, 2024. It is now read-only.

Commit 6ccfecb

Browse files
authored
Merge pull request #19 from occ-ai/roy.generic_onnx_rt_model_support
feat: Add YuNet model support for object detection
2 parents 1d80702 + a8e33d0 commit 6ccfecb

17 files changed

+712
-414
lines changed

CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ target_sources(
8787
src/detect-filter-info.c
8888
src/detect-filter-utils.cpp
8989
src/obs-utils/obs-utils.cpp
90+
src/ort-model/ONNXRuntimeModel.cpp
9091
src/edgeyolo/edgeyolo_onnxruntime.cpp
91-
src/sort/Sort.cpp)
92+
src/sort/Sort.cpp
93+
src/yunet/YuNet.cpp)
9294

9395
set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name})
227 KB
Binary file not shown.

src/FilterData.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#define FILTERDATA_H
33

44
#include <obs-module.h>
5-
#include "edgeyolo/edgeyolo_onnxruntime.hpp"
5+
#include "ort-model/ONNXRuntimeModel.h"
66
#include "sort/Sort.h"
77

88
/**
@@ -58,7 +58,7 @@ struct filter_data {
5858
std::mutex outputLock;
5959
std::mutex modelMutex;
6060

61-
std::unique_ptr<edgeyolo_cpp::EdgeYOLOONNXRuntime> edgeyolo;
61+
std::unique_ptr<ONNXRuntimeModel> onnxruntimemodel;
6262
std::vector<std::string> classNames;
6363

6464
#if _WIN32

src/detect-filter.cpp

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@
2424
#include "FilterData.h"
2525
#include "consts.h"
2626
#include "obs-utils/obs-utils.h"
27-
#include "edgeyolo/utils.hpp"
27+
#include "ort-model/utils.hpp"
2828
#include "detect-filter-utils.h"
29+
#include "edgeyolo/edgeyolo_onnxruntime.hpp"
30+
#include "yunet/YuNet.h"
2931

3032
#define EXTERNAL_MODEL_SIZE "!!!EXTERNAL_MODEL!!!"
33+
#define FACE_DETECT_MODEL_SIZE "!!!FACE_DETECT!!!"
3134

3235
struct detect_filter : public filter_data {};
3336

@@ -325,6 +328,8 @@ obs_properties_t *detect_filter_properties(void *data)
325328
obs_property_list_add_string(model_size, obs_module_text("SmallFast"), "small");
326329
obs_property_list_add_string(model_size, obs_module_text("Medium"), "medium");
327330
obs_property_list_add_string(model_size, obs_module_text("LargeSlow"), "large");
331+
obs_property_list_add_string(model_size, obs_module_text("FaceDetect"),
332+
FACE_DETECT_MODEL_SIZE);
328333
obs_property_list_add_string(model_size, obs_module_text("ExternalModel"),
329334
EXTERNAL_MODEL_SIZE);
330335

@@ -513,6 +518,9 @@ void detect_filter_update(void *data, obs_data_t *settings)
513518
} else if (newModelSize == "large") {
514519
modelFilepath_rawPtr =
515520
obs_module_file("models/edgeyolo_tiny_lrelu_coco_736x1280.onnx");
521+
} else if (newModelSize == FACE_DETECT_MODEL_SIZE) {
522+
modelFilepath_rawPtr =
523+
obs_module_file("models/face_detection_yunet_2023mar.onnx");
516524
} else if (newModelSize == EXTERNAL_MODEL_SIZE) {
517525
const char *external_model_file =
518526
obs_data_get_string(settings, "external_model_file");
@@ -580,41 +588,53 @@ void detect_filter_update(void *data, obs_data_t *settings)
580588
obs_log(LOG_ERROR,
581589
"JSON file does not contain 'labels' field");
582590
tf->isDisabled = true;
583-
tf->edgeyolo.reset();
591+
tf->onnxruntimemodel.reset();
584592
return;
585593
}
586594
} else {
587595
obs_log(LOG_ERROR, "Failed to open JSON file: %s",
588596
labelsFilepath.c_str());
589597
tf->isDisabled = true;
590-
tf->edgeyolo.reset();
598+
tf->onnxruntimemodel.reset();
591599
return;
592600
}
601+
} else if (tf->modelSize == FACE_DETECT_MODEL_SIZE) {
602+
num_classes_ = 1;
603+
tf->classNames = yunet::FACE_CLASSES;
593604
}
594605

595606
// Load model
596607
try {
597-
if (tf->edgeyolo) {
598-
tf->edgeyolo.reset();
608+
if (tf->onnxruntimemodel) {
609+
tf->onnxruntimemodel.reset();
610+
}
611+
if (tf->modelSize == FACE_DETECT_MODEL_SIZE) {
612+
tf->onnxruntimemodel = std::make_unique<yunet::YuNetONNX>(
613+
tf->modelFilepath, tf->numThreads, 50, tf->numThreads,
614+
tf->useGPU, onnxruntime_device_id_,
615+
onnxruntime_use_parallel_, nms_th_, tf->conf_threshold);
616+
} else {
617+
tf->onnxruntimemodel =
618+
std::make_unique<edgeyolo_cpp::EdgeYOLOONNXRuntime>(
619+
tf->modelFilepath, tf->numThreads, num_classes_,
620+
tf->numThreads, tf->useGPU, onnxruntime_device_id_,
621+
onnxruntime_use_parallel_, nms_th_,
622+
tf->conf_threshold);
599623
}
600-
tf->edgeyolo = std::make_unique<edgeyolo_cpp::EdgeYOLOONNXRuntime>(
601-
tf->modelFilepath, tf->numThreads, tf->numThreads, tf->useGPU,
602-
onnxruntime_device_id_, onnxruntime_use_parallel_, nms_th_,
603-
tf->conf_threshold, num_classes_);
604624
// clear error message
605625
obs_data_set_string(settings, "error", "");
606626
} catch (const std::exception &e) {
607627
obs_log(LOG_ERROR, "Failed to load model: %s", e.what());
608628
// disable filter
609629
tf->isDisabled = true;
610-
tf->edgeyolo.reset();
630+
tf->onnxruntimemodel.reset();
611631
return;
612632
}
613633
}
614634

615635
// update threshold on edgeyolo
616-
if (tf->edgeyolo) {
617-
tf->edgeyolo->setBBoxConfThresh(tf->conf_threshold);
636+
if (tf->onnxruntimemodel) {
637+
tf->onnxruntimemodel->setBBoxConfThresh(tf->conf_threshold);
618638
}
619639

620640
if (reinitialize) {
@@ -746,7 +766,7 @@ void detect_filter_video_tick(void *data, float seconds)
746766

747767
struct detect_filter *tf = reinterpret_cast<detect_filter *>(data);
748768

749-
if (tf->isDisabled || !tf->edgeyolo) {
769+
if (tf->isDisabled || !tf->onnxruntimemodel) {
750770
return;
751771
}
752772

@@ -775,18 +795,16 @@ void detect_filter_video_tick(void *data, float seconds)
775795
cropRect = cv::Rect(tf->crop_left, tf->crop_top,
776796
imageBGRA.cols - tf->crop_left - tf->crop_right,
777797
imageBGRA.rows - tf->crop_top - tf->crop_bottom);
778-
obs_log(LOG_INFO, "Crop: %d %d %d %d", cropRect.x, cropRect.y, cropRect.width,
779-
cropRect.height);
780798
cv::cvtColor(imageBGRA(cropRect), inferenceFrame, cv::COLOR_BGRA2BGR);
781799
} else {
782800
cv::cvtColor(imageBGRA, inferenceFrame, cv::COLOR_BGRA2BGR);
783801
}
784802

785-
std::vector<edgeyolo_cpp::Object> objects;
803+
std::vector<Object> objects;
786804

787805
try {
788806
std::unique_lock<std::mutex> lock(tf->modelMutex);
789-
objects = tf->edgeyolo->inference(inferenceFrame);
807+
objects = tf->onnxruntimemodel->inference(inferenceFrame);
790808
} catch (const Ort::Exception &e) {
791809
obs_log(LOG_ERROR, "ONNXRuntime Exception: %s", e.what());
792810
} catch (const std::exception &e) {
@@ -795,7 +813,7 @@ void detect_filter_video_tick(void *data, float seconds)
795813

796814
if (tf->crop_enabled) {
797815
// translate the detected objects to the original frame
798-
for (edgeyolo_cpp::Object &obj : objects) {
816+
for (Object &obj : objects) {
799817
obj.rect.x += (float)cropRect.x;
800818
obj.rect.y += (float)cropRect.y;
801819
}
@@ -824,8 +842,8 @@ void detect_filter_video_tick(void *data, float seconds)
824842
}
825843

826844
if (tf->objectCategory != -1) {
827-
std::vector<edgeyolo_cpp::Object> filtered_objects;
828-
for (const edgeyolo_cpp::Object &obj : objects) {
845+
std::vector<Object> filtered_objects;
846+
for (const Object &obj : objects) {
829847
if (obj.label == tf->objectCategory) {
830848
filtered_objects.push_back(obj);
831849
}
@@ -838,18 +856,17 @@ void detect_filter_video_tick(void *data, float seconds)
838856
}
839857

840858
if (!tf->showUnseenObjects) {
841-
objects.erase(std::remove_if(objects.begin(), objects.end(),
842-
[](const edgeyolo_cpp::Object &obj) {
843-
return obj.unseenFrames > 0;
844-
}),
845-
objects.end());
859+
objects.erase(
860+
std::remove_if(objects.begin(), objects.end(),
861+
[](const Object &obj) { return obj.unseenFrames > 0; }),
862+
objects.end());
846863
}
847864

848865
if (!tf->saveDetectionsPath.empty()) {
849866
std::ofstream detectionsFile(tf->saveDetectionsPath);
850867
if (detectionsFile.is_open()) {
851868
nlohmann::json j;
852-
for (const edgeyolo_cpp::Object &obj : objects) {
869+
for (const Object &obj : objects) {
853870
nlohmann::json obj_json;
854871
obj_json["label"] = obj.label;
855872
obj_json["confidence"] = obj.prob;
@@ -877,11 +894,11 @@ void detect_filter_video_tick(void *data, float seconds)
877894
drawDashedRectangle(frame, cropRect, cv::Scalar(0, 255, 0), 5, 8, 15);
878895
}
879896
if (tf->preview && objects.size() > 0) {
880-
edgeyolo_cpp::utils::draw_objects(frame, objects, tf->classNames);
897+
draw_objects(frame, objects, tf->classNames);
881898
}
882899
if (tf->maskingEnabled) {
883900
cv::Mat mask = cv::Mat::zeros(frame.size(), CV_8UC1);
884-
for (const edgeyolo_cpp::Object &obj : objects) {
901+
for (const Object &obj : objects) {
885902
cv::rectangle(mask, obj.rect, cv::Scalar(255), -1);
886903
}
887904
std::lock_guard<std::mutex> lock(tf->outputLock);
@@ -906,7 +923,7 @@ void detect_filter_video_tick(void *data, float seconds)
906923
// get the bounding box of all objects
907924
if (objects.size() > 0) {
908925
boundingBox = objects[0].rect;
909-
for (const edgeyolo_cpp::Object &obj : objects) {
926+
for (const Object &obj : objects) {
910927
boundingBox |= obj.rect;
911928
}
912929
}
@@ -967,7 +984,7 @@ void detect_filter_video_render(void *data, gs_effect_t *_effect)
967984

968985
struct detect_filter *tf = reinterpret_cast<detect_filter *>(data);
969986

970-
if (tf->isDisabled || !tf->edgeyolo) {
987+
if (tf->isDisabled || !tf->onnxruntimemodel) {
971988
if (tf->source) {
972989
obs_source_skip_video_filter(tf->source);
973990
}

src/edgeyolo/coco_names.hpp

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -30,33 +30,5 @@ static const std::vector<std::string> COCO_CLASSES = {
3030
"refrigerator", "book", "clock",
3131
"vase", "scissors", "teddy bear",
3232
"hair drier", "toothbrush"};
33-
const float color_list[80][3] = {
34-
{0.000f, 0.447f, 0.741f}, {0.850f, 0.325f, 0.098f}, {0.929f, 0.694f, 0.125f},
35-
{0.494f, 0.184f, 0.556f}, {0.466f, 0.674f, 0.188f}, {0.301f, 0.745f, 0.933f},
36-
{0.635f, 0.078f, 0.184f}, {0.300f, 0.300f, 0.300f}, {0.600f, 0.600f, 0.600f},
37-
{1.000f, 0.000f, 0.000f}, {1.000f, 0.500f, 0.000f}, {0.749f, 0.749f, 0.000f},
38-
{0.000f, 1.000f, 0.000f}, {0.000f, 0.000f, 1.000f}, {0.667f, 0.000f, 1.000f},
39-
{0.333f, 0.333f, 0.000f}, {0.333f, 0.667f, 0.000f}, {0.333f, 1.000f, 0.000f},
40-
{0.667f, 0.333f, 0.000f}, {0.667f, 0.667f, 0.000f}, {0.667f, 1.000f, 0.000f},
41-
{1.000f, 0.333f, 0.000f}, {1.000f, 0.667f, 0.000f}, {1.000f, 1.000f, 0.000f},
42-
{0.000f, 0.333f, 0.500f}, {0.000f, 0.667f, 0.500f}, {0.000f, 1.000f, 0.500f},
43-
{0.333f, 0.000f, 0.500f}, {0.333f, 0.333f, 0.500f}, {0.333f, 0.667f, 0.500f},
44-
{0.333f, 1.000f, 0.500f}, {0.667f, 0.000f, 0.500f}, {0.667f, 0.333f, 0.500f},
45-
{0.667f, 0.667f, 0.500f}, {0.667f, 1.000f, 0.500f}, {1.000f, 0.000f, 0.500f},
46-
{1.000f, 0.333f, 0.500f}, {1.000f, 0.667f, 0.500f}, {1.000f, 1.000f, 0.500f},
47-
{0.000f, 0.333f, 1.000f}, {0.000f, 0.667f, 1.000f}, {0.000f, 1.000f, 1.000f},
48-
{0.333f, 0.000f, 1.000f}, {0.333f, 0.333f, 1.000f}, {0.333f, 0.667f, 1.000f},
49-
{0.333f, 1.000f, 1.000f}, {0.667f, 0.000f, 1.000f}, {0.667f, 0.333f, 1.000f},
50-
{0.667f, 0.667f, 1.000f}, {0.667f, 1.000f, 1.000f}, {1.000f, 0.000f, 1.000f},
51-
{1.000f, 0.333f, 1.000f}, {1.000f, 0.667f, 1.000f}, {0.333f, 0.000f, 0.000f},
52-
{0.500f, 0.000f, 0.000f}, {0.667f, 0.000f, 0.000f}, {0.833f, 0.000f, 0.000f},
53-
{1.000f, 0.000f, 0.000f}, {0.000f, 0.167f, 0.000f}, {0.000f, 0.333f, 0.000f},
54-
{0.000f, 0.500f, 0.000f}, {0.000f, 0.667f, 0.000f}, {0.000f, 0.833f, 0.000f},
55-
{0.000f, 1.000f, 0.000f}, {0.000f, 0.000f, 0.167f}, {0.000f, 0.000f, 0.333f},
56-
{0.000f, 0.000f, 0.500f}, {0.000f, 0.000f, 0.667f}, {0.000f, 0.000f, 0.833f},
57-
{0.000f, 0.000f, 1.000f}, {0.000f, 0.000f, 0.000f}, {0.143f, 0.143f, 0.143f},
58-
{0.286f, 0.286f, 0.286f}, {0.429f, 0.429f, 0.429f}, {0.571f, 0.571f, 0.571f},
59-
{0.714f, 0.714f, 0.714f}, {0.857f, 0.857f, 0.857f}, {0.000f, 0.447f, 0.741f},
60-
{0.314f, 0.717f, 0.741f}, {0.50f, 0.5f, 0.0f}};
6133
} // namespace edgeyolo_cpp
6234
#endif

0 commit comments

Comments
 (0)