55 *
66 */
77
8- #include < iostream>
9-
108#include " SuperPoint.hpp"
119#include " Utility.hpp"
12- #include < opencv2/opencv.hpp>
13- #include < stdexcept>
1410
1511namespace
1612{
17- std::pair<std::vector<cv::KeyPoint>, cv::Mat> processOneFrame (const Ort::SuperPoint& osh, const cv::Mat& inputImg,
18- float * dst, int borderRemove = 4 ,
19- float confidenceThresh = 0.015 , bool alignCorners = true );
13+ using KeyPointAndDesc = std::pair<std::vector<cv::KeyPoint>, cv::Mat>;
14+
15+ KeyPointAndDesc processOneFrame (const Ort::SuperPoint& osh, const cv::Mat& inputImg, float * dst, int borderRemove = 4 ,
16+ float confidenceThresh = 0.015 , bool alignCorners = true );
17+
2018} // namespace
2119
2220int main (int argc, char * argv[])
@@ -46,15 +44,40 @@ int main(int argc, char* argv[])
4644 [](const auto & imagePath) { return cv::imread (imagePath, 0 ); });
4745
4846 std::vector<float > dst (Ort::SuperPoint::IMG_CHANNEL * Ort::SuperPoint::IMG_H * Ort::SuperPoint::IMG_W);
49- std::vector<std::pair<std::vector<cv::KeyPoint>, cv::Mat>> results;
47+
48+ std::vector<KeyPointAndDesc> results;
5049 std::transform (grays.begin (), grays.end (), std::back_inserter (results),
5150 [&osh, &dst](const auto & gray) { return processOneFrame (osh, gray, dst.data ()); });
5251
52+ cv::Ptr<cv::DescriptorMatcher> matcher = cv::DescriptorMatcher::create (cv::DescriptorMatcher::FLANNBASED);
53+ std::vector<std::vector<cv::DMatch>> knnMatches;
54+ const int numMatch = 2 ;
55+ matcher->knnMatch (results[0 ].second , results[1 ].second , knnMatches, numMatch);
56+
57+ std::vector<cv::DMatch> goodMatches;
58+ const float loweRatioThresh = 0.8 ;
59+ for (const auto & match : knnMatches) {
60+ if (match[0 ].distance < loweRatioThresh * match[1 ].distance ) {
61+ goodMatches.emplace_back (match[0 ]);
62+ }
63+ }
64+
65+ cv::Mat matchesImage;
66+ cv::drawMatches (images[0 ], results[0 ].first , images[1 ], results[1 ].first , goodMatches, matchesImage,
67+ cv::Scalar::all (-1 ), cv::Scalar::all (-1 ), std::vector<char >(),
68+ cv::DrawMatchesFlags::NOT_DRAW_SINGLE_POINTS);
69+ cv::imwrite (" good_matches.jpg" , matchesImage);
70+
5371 return EXIT_SUCCESS;
5472}
5573
5674namespace
5775{
76+ /* *
77+ * @brief detect super point
78+ *
79+ * @return vector of detected key points
80+ */
5881std::vector<cv::KeyPoint> getKeyPoints (const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
5982 int borderRemove, float confidenceThresh)
6083{
@@ -97,6 +120,11 @@ std::vector<cv::KeyPoint> getKeyPoints(const std::vector<Ort::OrtSessionHandler:
97120 return keyPoints;
98121}
99122
123+ /* *
124+ * @brief estimate super point's keypoint descriptor
125+ *
126+ * @return keypoint Mat of shape [num key point x 256]
127+ */
100128cv::Mat getDescriptors (const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
101129 int width, bool alignCorners)
102130{
@@ -109,13 +137,16 @@ cv::Mat getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::K
109137 }
110138 keyPointMat = keyPointMat.reshape (1 , {1 , 1 , static_cast <int >(keyPoints.size ()), 2 });
111139 cv::Mat descriptors = bilinearGridSample (coarseDescriptors, keyPointMat, alignCorners);
140+ descriptors = descriptors.reshape (1 , {coarseDescriptors.size [1 ], static_cast <int >(keyPoints.size ())});
141+
142+ cv::Mat buffer;
143+ transposeNDWrapper (descriptors, {1 , 0 }, buffer);
112144
113- return descriptors. reshape ( 1 , {coarseDescriptors. size [ 1 ], static_cast < int >(keyPoints. size ())}) ;
145+ return buffer ;
114146}
115147
116- std::pair<std::vector<cv::KeyPoint>, cv::Mat> processOneFrame (const Ort::SuperPoint& osh, const cv::Mat& inputImg,
117- float * dst, int borderRemove, float confidenceThresh,
118- bool alignCorners)
148+ KeyPointAndDesc processOneFrame (const Ort::SuperPoint& osh, const cv::Mat& inputImg, float * dst, int borderRemove,
149+ float confidenceThresh, bool alignCorners)
119150{
120151 int origW = inputImg.cols , origH = inputImg.rows ;
121152 std::vector<float > originImageSize{static_cast <float >(origH), static_cast <float >(origW)};
0 commit comments