Skip to content

Commit 8e78dcf

Browse files
committed
extract superpoint without nms
1 parent d1fcb67 commit 8e78dcf

11 files changed

+391
-22
lines changed

examples/CMakeLists.txt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,21 @@ target_include_directories(semantic_segmentation_paddleseg_bisenetv2
141141
PUBLIC
142142
${OpenCV_INCLUDE_DIRS}
143143
)
144+
145+
# ---------------------------------------------------------
146+
147+
# Example executable built from the SuperPoint wrapper and its demo application.
add_executable(super_point ${CMAKE_CURRENT_LIST_DIR}/SuperPoint.cpp ${CMAKE_CURRENT_LIST_DIR}/SuperPointApp.cpp)

# Link against the ${PROJECT_NAME} target and the OpenCV libraries.
target_link_libraries(super_point PUBLIC ${PROJECT_NAME} ${OpenCV_LIBS})

# Expose the OpenCV headers to the example sources.
target_include_directories(super_point PUBLIC ${OpenCV_INCLUDE_DIRS})

examples/SuperPoint.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/**
2+
* @file SuperPoint.cpp
3+
*
4+
* @author btran
5+
*
6+
*/
7+
8+
#include "SuperPoint.hpp"
9+
10+
namespace Ort
11+
{
12+
void SuperPoint::preprocess(float* dst, const unsigned char* src, const int64_t targetImgWidth,
13+
const int64_t targetImgHeight, const int numChannels) const
14+
{
15+
for (int i = 0; i < targetImgHeight; ++i) {
16+
for (int j = 0; j < targetImgWidth; ++j) {
17+
for (int c = 0; c < numChannels; ++c) {
18+
dst[c * targetImgHeight * targetImgWidth + i * targetImgWidth + j] =
19+
(src[i * targetImgWidth * numChannels + j * numChannels + c] / 255.0);
20+
}
21+
}
22+
}
23+
}
24+
} // namespace Ort
25+
// namespace Ort

examples/SuperPoint.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/**
2+
* @file SuperPoint.hpp
3+
*
4+
* @author btran
5+
*
6+
*/
7+
8+
#pragma once
9+
10+
#include <ort_utility/ort_utility.hpp>
11+
12+
namespace Ort
13+
{
14+
class SuperPoint : public OrtSessionHandler
15+
{
16+
public:
17+
static constexpr int64_t IMG_H = 480;
18+
static constexpr int64_t IMG_W = 640;
19+
static constexpr int64_t IMG_CHANNEL = 1;
20+
21+
using OrtSessionHandler::OrtSessionHandler;
22+
23+
void preprocess(float* dst, //
24+
const unsigned char* src, //
25+
const int64_t targetImgWidth, //
26+
const int64_t targetImgHeight, //
27+
const int numChannels) const;
28+
};
29+
} // namespace Ort

examples/SuperPointApp.cpp

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/**
2+
* @file SuperPointApp.cpp
3+
*
4+
* @author btran
5+
*
6+
*/
7+
8+
#include <iostream>
9+
10+
#include "SuperPoint.hpp"
11+
#include "Utility.hpp"
12+
#include <opencv2/opencv.hpp>
13+
#include <stdexcept>
14+
15+
namespace
16+
{
17+
/**
 * @brief Run the SuperPoint session on one image and return its key points with their descriptors.
 *
 * Forward declaration; the definition lives in the anonymous namespace further down this file.
 *
 * @param osh session handler used for preprocessing and inference
 * @param inputImg image to process
 * @param dst caller-provided float buffer that receives the preprocessed network input
 * @param borderRemove margin, in pixels, excluded from key point detection (default 4)
 * @param confidenceThresh a pixel becomes a key point only if its score exceeds this value (default 0.015)
 * @param alignCorners forwarded to the bilinear grid sampler used for descriptors (default true)
 * @return pair of (key points, descriptor matrix)
 */
std::pair<std::vector<cv::KeyPoint>, cv::Mat> processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& inputImg,
18+
float* dst, int borderRemove = 4,
19+
float confidenceThresh = 0.015, bool alignCorners = true);
20+
} // namespace
21+
22+
int main(int argc, char* argv[])
23+
{
24+
if (argc != 4) {
25+
std::cerr << "Usage: [apps] [path/to/onnx/super/point] [path/to/image1] [path/to/image2]" << std::endl;
26+
return EXIT_FAILURE;
27+
}
28+
29+
const std::string ONNX_MODEL_PATH = argv[1];
30+
const std::vector<std::string> IMAGE_PATHS = {argv[2], argv[3]};
31+
32+
Ort::SuperPoint osh(ONNX_MODEL_PATH, 0,
33+
std::vector<std::vector<int64_t>>{
34+
{1, Ort::SuperPoint::IMG_CHANNEL, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W}});
35+
36+
std::vector<cv::Mat> images;
37+
std::vector<cv::Mat> grays;
38+
std::transform(IMAGE_PATHS.begin(), IMAGE_PATHS.end(), std::back_inserter(images),
39+
[](const auto& imagePath) { return cv::imread(imagePath); });
40+
for (int i = 0; i < 2; ++i) {
41+
if (images[i].empty()) {
42+
throw std::runtime_error("failed to open " + IMAGE_PATHS[i]);
43+
}
44+
}
45+
std::transform(IMAGE_PATHS.begin(), IMAGE_PATHS.end(), std::back_inserter(grays),
46+
[](const auto& imagePath) { return cv::imread(imagePath, 0); });
47+
48+
std::vector<float> dst(Ort::SuperPoint::IMG_CHANNEL * Ort::SuperPoint::IMG_H * Ort::SuperPoint::IMG_W);
49+
std::vector<std::pair<std::vector<cv::KeyPoint>, cv::Mat>> results;
50+
std::transform(grays.begin(), grays.end(), std::back_inserter(results),
51+
[&osh, &dst](const auto& gray) { return processOneFrame(osh, gray, dst.data()); });
52+
53+
return EXIT_SUCCESS;
54+
}
55+
56+
namespace
57+
{
58+
/**
 * @brief Turn the detector head output into a list of key points.
 *
 * The first inference output is viewed without its leading dimension, moved to a channel-last
 * layout, passed through a softmax per cell, stripped of its last channel, and unfolded into a
 * full-resolution score map. Every pixel outside the border margin whose score exceeds the
 * threshold becomes a key point.
 *
 * @param inferenceOutput outputs of the session; element 0 is used as the detector output
 * @param borderRemove margin, in pixels, skipped on every side of the score map
 * @param confidenceThresh a pixel is kept only if its score is strictly greater than this value
 * @return key points with pt set to (column, row) of the score map and response set to the score
 */
std::vector<cv::KeyPoint> getKeyPoints(const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
                                       int borderRemove, float confidenceThresh)
{
    const auto& detectorOutput = inferenceOutput[0];

    // Skip the first dimension of the reported shape.
    std::vector<int> detectorShape(detectorOutput.second.begin() + 1, detectorOutput.second.end());
    const int numBins = detectorShape[0];
    const int cellRows = detectorShape[1];
    const int cellCols = detectorShape[2];

    // Non-owning view over the raw output buffer: 65 x H/8 x W/8
    cv::Mat scores(detectorShape.size(), detectorShape.data(), CV_32F, detectorOutput.first);
    cv::Mat scratch;

    transposeNDWrapper(scores, {1, 2, 0}, scratch);
    scratch.copyTo(scores);  // H/8 x W/8 x 65

    // Softmax over the channel axis of every cell.
    for (int cellRow = 0; cellRow < cellRows; ++cellRow) {
        for (int cellCol = 0; cellCol < cellCols; ++cellCol) {
            Ort::softmax(scores.ptr<float>(cellRow, cellCol), numBins);
        }
    }

    // Drop the last channel: H/8 x W/8 x 64
    scores = scores({cv::Range::all(), cv::Range::all(), cv::Range(0, numBins - 1)}).clone();

    // Unfold each cell's 64 values into an 8 x 8 patch: H/8 x W/8 x 8 x 8
    scores = scores.reshape(1, {cellRows, cellCols, 8, 8});
    transposeNDWrapper(scores, {0, 2, 1, 3}, scratch);
    scratch.copyTo(scores);  // H/8 x 8 x W/8 x 8

    scores = scores.reshape(1, {cellRows * 8, cellCols * 8});  // H x W

    std::vector<cv::KeyPoint> keyPoints;
    const int rowEnd = scores.rows - borderRemove;
    const int colEnd = scores.cols - borderRemove;
    for (int y = borderRemove; y < rowEnd; ++y) {
        const float* const scoreRow = scores.ptr<float>(y);
        for (int x = borderRemove; x < colEnd; ++x) {
            const float score = scoreRow[x];
            if (!(score > confidenceThresh)) {
                continue;
            }

            cv::KeyPoint keyPoint;
            keyPoint.pt.x = x;
            keyPoint.pt.y = y;
            keyPoint.response = score;
            keyPoints.emplace_back(keyPoint);
        }
    }

    return keyPoints;
}
99+
100+
/**
 * @brief Sample one descriptor per key point from the coarse descriptor map.
 *
 * Key point coordinates are mapped to the [-1, 1] range using (height - 1) and (width - 1),
 * packed into a 1 x 1 x N x 2 grid, and handed to bilinearGridSample.
 *
 * @param coarseDescriptors descriptor map whose dimension 1 is the descriptor length
 * @param keyPoints key points expressed in the height x width pixel grid
 * @param height image height used for coordinate normalisation
 * @param width image width used for coordinate normalisation
 * @param alignCorners forwarded to bilinearGridSample
 * @return matrix of shape (descriptor length) x (number of key points); it has zero columns
 *         when keyPoints is empty
 */
cv::Mat getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
                       int width, bool alignCorners)
{
    const int numKeyPoints = static_cast<int>(keyPoints.size());
    const int descriptorLength = coarseDescriptors.size[1];

    // cv::Mat::reshape treats a zero entry as "copy the source dimension", so reshaping an empty
    // N x 2 matrix to {1, 1, 0, 2} is rejected. Nothing needs sampling in that case anyway.
    if (numKeyPoints == 0) {
        return cv::Mat(descriptorLength, 0, CV_32F);
    }

    cv::Mat keyPointMat(numKeyPoints, 2, CV_32F);
    for (int i = 0; i < numKeyPoints; ++i) {
        auto rowPtr = keyPointMat.ptr<float>(i);
        // NOTE(review): column 0 holds the normalised y and column 1 the normalised x;
        // confirm this matches the coordinate order bilinearGridSample expects.
        rowPtr[0] = 2 * keyPoints[i].pt.y / (height - 1) - 1;
        rowPtr[1] = 2 * keyPoints[i].pt.x / (width - 1) - 1;
    }

    keyPointMat = keyPointMat.reshape(1, {1, 1, numKeyPoints, 2});
    cv::Mat descriptors = bilinearGridSample(coarseDescriptors, keyPointMat, alignCorners);

    return descriptors.reshape(1, {descriptorLength, numKeyPoints});
}
115+
116+
/**
 * @brief Run the SuperPoint session on one image and return its key points with their descriptors.
 *
 * The image is resized to IMG_W x IMG_H, preprocessed into dst, and passed through the session.
 * Key points come from output 0 and descriptors are sampled from output 1. Key point
 * coordinates are finally rescaled from the network resolution back to the input image size.
 *
 * @param osh session handler used for preprocessing and inference
 * @param inputImg image to process
 * @param dst caller-provided buffer that receives the preprocessed network input
 * @param borderRemove margin, in pixels, excluded from key point detection
 * @param confidenceThresh a pixel becomes a key point only if its score exceeds this value
 * @param alignCorners forwarded to the descriptor sampling step
 * @return pair of (key points in input image coordinates, descriptor matrix)
 */
std::pair<std::vector<cv::KeyPoint>, cv::Mat> processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& inputImg,
                                                              float* dst, int borderRemove, float confidenceThresh,
                                                              bool alignCorners)
{
    const int origW = inputImg.cols;
    const int origH = inputImg.rows;

    cv::Mat scaledImg;
    cv::resize(inputImg, scaledImg, cv::Size(Ort::SuperPoint::IMG_W, Ort::SuperPoint::IMG_H), 0, 0, cv::INTER_CUBIC);
    osh.preprocess(dst, scaledImg.data, Ort::SuperPoint::IMG_W, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_CHANNEL);
    auto inferenceOutput = osh({dst});

    std::vector<cv::KeyPoint> keyPoints = getKeyPoints(inferenceOutput, borderRemove, confidenceThresh);

    // Non-owning view over the raw descriptor output: 1 x 256 x H/8 x W/8
    std::vector<int> descriptorShape(inferenceOutput[1].second.begin(), inferenceOutput[1].second.end());
    cv::Mat coarseDescriptorMat(descriptorShape.size(), descriptorShape.data(), CV_32F, inferenceOutput[1].first);

    // Descriptors are sampled while the key points are still in network-resolution coordinates.
    cv::Mat descriptors =
        getDescriptors(coarseDescriptorMat, keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, alignCorners);

    // Map key points back to the resolution of the original image.
    const float scaleX = static_cast<float>(origW) / Ort::SuperPoint::IMG_W;
    const float scaleY = static_cast<float>(origH) / Ort::SuperPoint::IMG_H;
    for (auto& keyPoint : keyPoints) {
        keyPoint.pt.x *= scaleX;
        keyPoint.pt.y *= scaleY;
    }

    return {keyPoints, descriptors};
}
143+
} // namespace

0 commit comments

Comments
 (0)