Skip to content

Commit fa97e64

Browse files
Updated initializer and SegmentAnything modules to store the data to the custom result structs properly
1 parent 959a3ff commit fa97e64

File tree

5 files changed

+22
-18
lines changed

5 files changed

+22
-18
lines changed

CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ include_directories(${OpenCV_INCLUDE_DIRS})
1919

2020
# -------------- ONNXRuntime ------------------#
2121
set(ONNXRUNTIME_VERSION 1.21.0)
22-
set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../hero_sam/onnxruntime-linux-x64-gpu-1.21.1")
22+
set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../hero_sam.bak/onnxruntime-linux-x64-gpu-1.21.1")
2323
include_directories(${ONNXRUNTIME_ROOT}/include)
2424

2525
# -------------- Cuda ------------------#
@@ -84,8 +84,8 @@ add_executable(${PROJECT_NAME} src/main.cpp)
8484
target_link_libraries(${PROJECT_NAME} sam_onnx_ros_core)
8585

8686
# Copy sam_<model>.onnx file to the same folder of the executable file
87-
configure_file(~/Documents/repos/hero_sam/sam_inference/model/SAM_mask_decoder.onnx ${CMAKE_CURRENT_BINARY_DIR}/SAM_mask_decoder.onnx COPYONLY)
88-
configure_file(~/Documents/repos/hero_sam/sam_inference/model/SAM_encoder.onnx ${CMAKE_CURRENT_BINARY_DIR}/SAM_encoder.onnx COPYONLY)
87+
configure_file(~/Documents/repos/hero_sam.bak/sam_inference/model/SAM_mask_decoder.onnx ${CMAKE_CURRENT_BINARY_DIR}/SAM_mask_decoder.onnx COPYONLY)
88+
configure_file(~/Documents/repos/hero_sam.bak/sam_inference/model/SAM_encoder.onnx ${CMAKE_CURRENT_BINARY_DIR}/SAM_encoder.onnx COPYONLY)
8989

9090
# Create folder name images in the same folder of the executable file
9191
add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD

include/segmentation.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
#include <tuple>
55

66
#include "sam_inference.h"
7-
std::tuple<std::vector<std::unique_ptr<SAM>>, SEG::_DL_INIT_PARAM, SEG::_DL_INIT_PARAM> Initializer();
8-
std::vector<cv::Mat> SegmentAnything(std::vector<std::unique_ptr<SAM>>& samSegmentors, const SEG::_DL_INIT_PARAM& params_encoder, const SEG::_DL_INIT_PARAM& params_decoder, cv::Mat& img);
7+
std::tuple<std::vector<std::unique_ptr<SAM>>, SEG::_DL_INIT_PARAM, SEG::_DL_INIT_PARAM, SEG::DL_RESULT, std::vector<SEG::DL_RESULT>> Initializer();
8+
void SegmentAnything(std::vector<std::unique_ptr<SAM>>& samSegmentors, const SEG::_DL_INIT_PARAM& params_encoder, const SEG::_DL_INIT_PARAM& params_decoder, const cv::Mat& img,
9+
std::vector<SEG::DL_RESULT> &resSam,
10+
SEG::DL_RESULT &res);
911

1012
#endif // SEGMENTATION_H

src/main.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ int main()
1010
std::vector<std::unique_ptr<SAM>> samSegmentors;
1111
SEG::DL_INIT_PARAM params_encoder;
1212
SEG::DL_INIT_PARAM params_decoder;
13-
std::tie(samSegmentors, params_encoder, params_decoder) = Initializer();
13+
std::vector<SEG::DL_RESULT> resSam;
14+
SEG::DL_RESULT res;
15+
std::tie(samSegmentors, params_encoder, params_decoder, res, resSam) = Initializer();
1416
std::filesystem::path current_path = std::filesystem::current_path();
1517
std::filesystem::path imgs_path = "/home/amigo/Documents/repos/hero_sam/sam_inference/build/images"; // current_path / <- you could use
1618
for (auto &i : std::filesystem::directory_iterator(imgs_path))
@@ -19,8 +21,8 @@ int main()
1921
{
2022
std::string img_path = i.path().string();
2123
cv::Mat img = cv::imread(img_path);
22-
std::vector<cv::Mat> masks;
23-
masks = SegmentAnything(samSegmentors, params_encoder, params_decoder, img);
24+
25+
SegmentAnything(samSegmentors, params_encoder, params_decoder, img, resSam, res);
2426

2527
}
2628
}

src/segmentation.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include "segmentation.h"
22

33
std::tuple<std::vector<std::unique_ptr<SAM>>, SEG::DL_INIT_PARAM,
4-
SEG::DL_INIT_PARAM>
4+
SEG::DL_INIT_PARAM, SEG::DL_RESULT, std::vector<SEG::DL_RESULT>>
55
Initializer() {
66
std::vector<std::unique_ptr<SAM>> samSegmentors;
77
samSegmentors.push_back(std::make_unique<SAM>());
@@ -11,7 +11,8 @@ Initializer() {
1111
std::unique_ptr<SAM> samSegmentorDecoder = std::make_unique<SAM>();
1212
SEG::DL_INIT_PARAM params_encoder;
1313
SEG::DL_INIT_PARAM params_decoder;
14-
14+
SEG::DL_RESULT res;
15+
std::vector<SEG::DL_RESULT> resSam;
1516
params_encoder.rectConfidenceThreshold = 0.1;
1617
params_encoder.iouThreshold = 0.5;
1718
params_encoder.modelPath = "/home/amigo//Documents/repos/sam_onnx_ros/build/SAM_encoder.onnx";
@@ -31,22 +32,21 @@ Initializer() {
3132
samSegmentorDecoder->CreateSession(params_decoder);
3233
samSegmentors[0] = std::move(samSegmentorEncoder);
3334
samSegmentors[1] = std::move(samSegmentorDecoder);
34-
return {std::move(samSegmentors), params_encoder, params_decoder};
35+
return {std::move(samSegmentors), params_encoder, params_decoder, res, resSam};
3536
}
3637

37-
std::vector<cv::Mat>
38-
SegmentAnything(std::vector<std::unique_ptr<SAM>> &samSegmentors,
38+
void SegmentAnything(std::vector<std::unique_ptr<SAM>> &samSegmentors,
3939
const SEG::DL_INIT_PARAM &params_encoder,
40-
const SEG::DL_INIT_PARAM &params_decoder, cv::Mat &img) {
40+
const SEG::DL_INIT_PARAM &params_decoder, const cv::Mat &img, std::vector<SEG::DL_RESULT> &resSam,
41+
SEG::DL_RESULT &res) {
42+
4143

42-
std::vector<SEG::DL_RESULT> resSam;
43-
SEG::DL_RESULT res;
4444

4545
SEG::MODEL_TYPE modelTypeRef = params_encoder.modelType;
4646
samSegmentors[0]->RunSession(img, resSam, modelTypeRef, res);
4747

4848
modelTypeRef = params_decoder.modelType;
4949
samSegmentors[1]->RunSession(img, resSam, modelTypeRef, res);
5050

51-
return std::move(res.masks);
51+
// return std::move(res.masks);
5252
}

src/utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include "utils.h"
22
#include <opencv2/ximgproc/edge_filter.hpp> // for guided filter
3-
#define LOGGING
3+
//#define LOGGING
44

55
// Constructor
66
Utils::Utils()

0 commit comments

Comments
 (0)