Skip to content

Commit 7f30d9d

Browse files
committed
match key points
1 parent 8e78dcf commit 7f30d9d

File tree

1 file changed

+43
-12
lines changed

1 file changed

+43
-12
lines changed

examples/SuperPointApp.cpp

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,16 @@
55
*
66
*/
77

8-
#include <iostream>
9-
108
#include "SuperPoint.hpp"
119
#include "Utility.hpp"
12-
#include <opencv2/opencv.hpp>
13-
#include <stdexcept>
1410

1511
namespace
1612
{
17-
std::pair<std::vector<cv::KeyPoint>, cv::Mat> processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& inputImg,
18-
float* dst, int borderRemove = 4,
19-
float confidenceThresh = 0.015, bool alignCorners = true);
13+
using KeyPointAndDesc = std::pair<std::vector<cv::KeyPoint>, cv::Mat>;
14+
15+
KeyPointAndDesc processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& inputImg, float* dst, int borderRemove = 4,
16+
float confidenceThresh = 0.015, bool alignCorners = true);
17+
2018
} // namespace
2119

2220
int main(int argc, char* argv[])
@@ -46,15 +44,40 @@ int main(int argc, char* argv[])
4644
[](const auto& imagePath) { return cv::imread(imagePath, 0); });
4745

4846
std::vector<float> dst(Ort::SuperPoint::IMG_CHANNEL * Ort::SuperPoint::IMG_H * Ort::SuperPoint::IMG_W);
49-
std::vector<std::pair<std::vector<cv::KeyPoint>, cv::Mat>> results;
47+
48+
std::vector<KeyPointAndDesc> results;
5049
std::transform(grays.begin(), grays.end(), std::back_inserter(results),
5150
[&osh, &dst](const auto& gray) { return processOneFrame(osh, gray, dst.data()); });
5251

52+
cv::Ptr<cv::DescriptorMatcher> matcher = cv::DescriptorMatcher::create(cv::DescriptorMatcher::FLANNBASED);
53+
std::vector<std::vector<cv::DMatch>> knnMatches;
54+
const int numMatch = 2;
55+
matcher->knnMatch(results[0].second, results[1].second, knnMatches, numMatch);
56+
57+
std::vector<cv::DMatch> goodMatches;
58+
const float loweRatioThresh = 0.8;
59+
for (const auto& match : knnMatches) {
60+
if (match[0].distance < loweRatioThresh * match[1].distance) {
61+
goodMatches.emplace_back(match[0]);
62+
}
63+
}
64+
65+
cv::Mat matchesImage;
66+
cv::drawMatches(images[0], results[0].first, images[1], results[1].first, goodMatches, matchesImage,
67+
cv::Scalar::all(-1), cv::Scalar::all(-1), std::vector<char>(),
68+
cv::DrawMatchesFlags::NOT_DRAW_SINGLE_POINTS);
69+
cv::imwrite("good_matches.jpg", matchesImage);
70+
5371
return EXIT_SUCCESS;
5472
}
5573

5674
namespace
5775
{
76+
/**
77+
* @brief detect super point
78+
*
79+
* @return vector of detected key points
80+
*/
5881
std::vector<cv::KeyPoint> getKeyPoints(const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
5982
int borderRemove, float confidenceThresh)
6083
{
@@ -97,6 +120,11 @@ std::vector<cv::KeyPoint> getKeyPoints(const std::vector<Ort::OrtSessionHandler:
97120
return keyPoints;
98121
}
99122

123+
/**
124+
* @brief estimate super point's keypoint descriptor
125+
*
126+
* @return keypoint Mat of shape [num key point x 256]
127+
*/
100128
cv::Mat getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
101129
int width, bool alignCorners)
102130
{
@@ -109,13 +137,16 @@ cv::Mat getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::K
109137
}
110138
keyPointMat = keyPointMat.reshape(1, {1, 1, static_cast<int>(keyPoints.size()), 2});
111139
cv::Mat descriptors = bilinearGridSample(coarseDescriptors, keyPointMat, alignCorners);
140+
descriptors = descriptors.reshape(1, {coarseDescriptors.size[1], static_cast<int>(keyPoints.size())});
141+
142+
cv::Mat buffer;
143+
transposeNDWrapper(descriptors, {1, 0}, buffer);
112144

113-
return descriptors.reshape(1, {coarseDescriptors.size[1], static_cast<int>(keyPoints.size())});
145+
return buffer;
114146
}
115147

116-
std::pair<std::vector<cv::KeyPoint>, cv::Mat> processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& inputImg,
117-
float* dst, int borderRemove, float confidenceThresh,
118-
bool alignCorners)
148+
KeyPointAndDesc processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& inputImg, float* dst, int borderRemove,
149+
float confidenceThresh, bool alignCorners)
119150
{
120151
int origW = inputImg.cols, origH = inputImg.rows;
121152
std::vector<float> originImageSize{static_cast<float>(origH), static_cast<float>(origW)};

0 commit comments

Comments
 (0)