77
88#include " SuperPoint.hpp"
99#include " Utility.hpp"
10+ #include < algorithm>
11+ #include < iterator>
1012
1113namespace
1214{
1315using KeyPointAndDesc = std::pair<std::vector<cv::KeyPoint>, cv::Mat>;
1416
1517KeyPointAndDesc processOneFrame (const Ort::SuperPoint& osh, const cv::Mat& inputImg, float * dst, int borderRemove = 4 ,
16- float confidenceThresh = 0.015 , bool alignCorners = true );
18+ float confidenceThresh = 0.015 , bool alignCorners = true , int distThresh = 2 );
1719
20+ /* *
21+ * @brief detect super point
22+ *
23+ * @return vector of detected key points
24+ */
25+ std::vector<cv::KeyPoint> getKeyPoints (const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
26+ int borderRemove, float confidenceThresh);
27+
28+ /* *
29+ * @brief estimate super point's keypoint descriptor
30+ *
31+ * @return keypoint Mat of shape [num key point x 256]
32+ */
33+ cv::Mat getDescriptors (const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
34+ int width, bool alignCorners);
35+
36+ std::vector<int > nmsFast (const std::vector<cv::KeyPoint>& keyPoints, int height, int width, int distThresh = 2 );
1837} // namespace
1938
2039int main (int argc, char * argv[])
@@ -66,18 +85,13 @@ int main(int argc, char* argv[])
6685 cv::drawMatches (images[0 ], results[0 ].first , images[1 ], results[1 ].first , goodMatches, matchesImage,
6786 cv::Scalar::all (-1 ), cv::Scalar::all (-1 ), std::vector<char >(),
6887 cv::DrawMatchesFlags::NOT_DRAW_SINGLE_POINTS);
69- cv::imwrite (" good_matches .jpg" , matchesImage);
88+ cv::imwrite (" super_point_good_matches .jpg" , matchesImage);
7089
7190 return EXIT_SUCCESS;
7291}
7392
7493namespace
7594{
76- /* *
77- * @brief detect super point
78- *
79- * @return vector of detected key points
80- */
8195std::vector<cv::KeyPoint> getKeyPoints (const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
8296 int borderRemove, float confidenceThresh)
8397{
@@ -119,12 +133,6 @@ std::vector<cv::KeyPoint> getKeyPoints(const std::vector<Ort::OrtSessionHandler:
119133
120134 return keyPoints;
121135}
122-
123- /* *
124- * @brief estimate super point's keypoint descriptor
125- *
126- * @return keypoint Mat of shape [num key point x 256]
127- */
128136cv::Mat getDescriptors (const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
129137 int width, bool alignCorners)
130138{
@@ -145,8 +153,50 @@ cv::Mat getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::K
145153 return buffer;
146154}
147155
156+ std::vector<int > nmsFast (const std::vector<cv::KeyPoint>& keyPoints, int height, int width, int distThresh)
157+ {
158+ static const int TO_PROCESS = 1 ;
159+ static const int EMPTY_OR_SUPPRESSED = 0 ;
160+ static const int KEPT = -1 ;
161+
162+ std::vector<int > sortedIndices (keyPoints.size ());
163+ std::iota (sortedIndices.begin (), sortedIndices.end (), 0 );
164+
165+ // sort in descending order base on confidence
166+ std::stable_sort (sortedIndices.begin (), sortedIndices.end (),
167+ [&keyPoints](int lidx, int ridx) { return keyPoints[lidx].response > keyPoints[ridx].response ; });
168+
169+ cv::Mat grid = cv::Mat (height, width, CV_8S, TO_PROCESS);
170+ std::vector<int > keepIndices;
171+
172+ for (int idx : sortedIndices) {
173+ int x = keyPoints[idx].pt .x ;
174+ int y = keyPoints[idx].pt .y ;
175+
176+ if (grid.at <schar>(y, x) == TO_PROCESS) {
177+ for (int i = y - distThresh; i < y + distThresh; ++i) {
178+ if (i < 0 || i >= height) {
179+ continue ;
180+ }
181+
182+ for (int j = x - distThresh; j < x + distThresh; ++j) {
183+ if (j < 0 || j >= width) {
184+ continue ;
185+ }
186+ grid.at <int >(i, j) = EMPTY_OR_SUPPRESSED;
187+ }
188+ }
189+
190+ grid.at <int >(y, x) = KEPT;
191+ keepIndices.emplace_back (idx);
192+ }
193+ }
194+
195+ return keepIndices;
196+ }
197+
148198KeyPointAndDesc processOneFrame (const Ort::SuperPoint& osh, const cv::Mat& inputImg, float * dst, int borderRemove,
149- float confidenceThresh, bool alignCorners)
199+ float confidenceThresh, bool alignCorners, int distThresh )
150200{
151201 int origW = inputImg.cols , origH = inputImg.rows ;
152202 std::vector<float > originImageSize{static_cast <float >(origH), static_cast <float >(origW)};
@@ -161,6 +211,14 @@ KeyPointAndDesc processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& input
161211 cv::Mat coarseDescriptorMat (descriptorShape.size (), descriptorShape.data (), CV_32F,
162212 inferenceOutput[1 ].first ); // 1 x 256 x H/8 x W/8
163213
214+ std::vector<int > keepIndices = nmsFast (keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, distThresh);
215+
216+ std::vector<cv::KeyPoint> keepKeyPoints;
217+ keepKeyPoints.reserve (keepIndices.size ());
218+ std::transform (keepIndices.begin (), keepIndices.end (), std::back_inserter (keepKeyPoints),
219+ [&keyPoints](int idx) { return keyPoints[idx]; });
220+ keyPoints = std::move (keepKeyPoints);
221+
164222 cv::Mat descriptors =
165223 getDescriptors (coarseDescriptorMat, keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, alignCorners);
166224
0 commit comments