@@ -15,23 +15,6 @@ using KeyPointAndDesc = std::pair<std::vector<cv::KeyPoint>, cv::Mat>;
1515KeyPointAndDesc processOneFrame (const Ort::SuperPoint& osh, const cv::Mat& inputImg, float * dst, int borderRemove = 4 ,
1616 float confidenceThresh = 0.015 , bool alignCorners = true , int distThresh = 2 );
1717
18- /* *
19- * @brief detect SuperPoint key points from the detector output of the network
20- *
21- * @return vector of detected key points
22- */
23- std::vector<cv::KeyPoint> getKeyPoints (const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
24- int borderRemove, float confidenceThresh);
25-
26- /* *
27- * @brief sample SuperPoint descriptors at the given key points
28- *
29- * @return descriptor Mat of shape [num key points x 256]
30- */
31- cv::Mat getDescriptors (const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
32- int width, bool alignCorners);
33-
34- std::vector<int > nmsFast (const std::vector<cv::KeyPoint>& keyPoints, int height, int width, int distThresh = 2 );
3518} // namespace
3619
3720int main (int argc, char * argv[])
@@ -92,109 +75,6 @@ int main(int argc, char* argv[])
9275
9376namespace
9477{
95- std::vector<cv::KeyPoint> getKeyPoints (const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
96- int borderRemove, float confidenceThresh)
97- {
98- std::vector<int > detectorShape (inferenceOutput[0 ].second .begin () + 1 , inferenceOutput[0 ].second .end ());
99-
100- cv::Mat detectorMat (detectorShape.size (), detectorShape.data (), CV_32F,
101- inferenceOutput[0 ].first ); // 65 x H/8 x W/8
102- cv::Mat buffer;
103-
104- transposeNDWrapper (detectorMat, {1 , 2 , 0 }, buffer);
105- buffer.copyTo (detectorMat); // H/8 x W/8 x 65
106-
107- for (int i = 0 ; i < detectorShape[1 ]; ++i) {
108- for (int j = 0 ; j < detectorShape[2 ]; ++j) {
109- Ort::softmax (detectorMat.ptr <float >(i, j), detectorShape[0 ]);
110- }
111- }
112- detectorMat = detectorMat ({cv::Range::all (), cv::Range::all (), cv::Range (0 , detectorShape[0 ] - 1 )})
113- .clone (); // H/8 x W/8 x 64
114- detectorMat = detectorMat.reshape (1 , {detectorShape[1 ], detectorShape[2 ], 8 , 8 }); // H/8 x W/8 x 8 x 8
115- transposeNDWrapper (detectorMat, {0 , 2 , 1 , 3 }, buffer);
116- buffer.copyTo (detectorMat); // H/8 x 8 x W/8 x 8
117-
118- detectorMat = detectorMat.reshape (1 , {detectorShape[1 ] * 8 , detectorShape[2 ] * 8 }); // H x W
119-
120- std::vector<cv::KeyPoint> keyPoints;
121- for (int i = borderRemove; i < detectorMat.rows - borderRemove; ++i) {
122- auto rowPtr = detectorMat.ptr <float >(i);
123- for (int j = borderRemove; j < detectorMat.cols - borderRemove; ++j) {
124- if (rowPtr[j] > confidenceThresh) {
125- cv::KeyPoint keyPoint;
126- keyPoint.pt .x = j;
127- keyPoint.pt .y = i;
128- keyPoint.response = rowPtr[j];
129- keyPoints.emplace_back (keyPoint);
130- }
131- }
132- }
133-
134- return keyPoints;
135- }
136- cv::Mat getDescriptors (const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
137- int width, bool alignCorners)
138- {
139- cv::Mat keyPointMat (keyPoints.size (), 2 , CV_32F);
140-
141- for (int i = 0 ; i < keyPoints.size (); ++i) {
142- auto rowPtr = keyPointMat.ptr <float >(i);
143- rowPtr[0 ] = 2 * keyPoints[i].pt .y / (height - 1 ) - 1 ;
144- rowPtr[1 ] = 2 * keyPoints[i].pt .x / (width - 1 ) - 1 ;
145- }
146- keyPointMat = keyPointMat.reshape (1 , {1 , 1 , static_cast <int >(keyPoints.size ()), 2 });
147- cv::Mat descriptors = bilinearGridSample (coarseDescriptors, keyPointMat, alignCorners);
148- descriptors = descriptors.reshape (1 , {coarseDescriptors.size [1 ], static_cast <int >(keyPoints.size ())});
149-
150- cv::Mat buffer;
151- transposeNDWrapper (descriptors, {1 , 0 }, buffer);
152-
153- return buffer;
154- }
155-
156- std::vector<int > nmsFast (const std::vector<cv::KeyPoint>& keyPoints, int height, int width, int distThresh)
157- {
158- static const int TO_PROCESS = 1 ;
159- static const int EMPTY_OR_SUPPRESSED = 0 ;
160- static const int KEPT = -1 ;
161-
162- std::vector<int > sortedIndices (keyPoints.size ());
163- std::iota (sortedIndices.begin (), sortedIndices.end (), 0 );
164-
165- // sort in descending order base on confidence
166- std::stable_sort (sortedIndices.begin (), sortedIndices.end (),
167- [&keyPoints](int lidx, int ridx) { return keyPoints[lidx].response > keyPoints[ridx].response ; });
168-
169- cv::Mat grid = cv::Mat (height, width, CV_8S, TO_PROCESS);
170- std::vector<int > keepIndices;
171-
172- for (int idx : sortedIndices) {
173- int x = keyPoints[idx].pt .x ;
174- int y = keyPoints[idx].pt .y ;
175-
176- if (grid.at <schar>(y, x) == TO_PROCESS) {
177- for (int i = y - distThresh; i < y + distThresh; ++i) {
178- if (i < 0 || i >= height) {
179- continue ;
180- }
181-
182- for (int j = x - distThresh; j < x + distThresh; ++j) {
183- if (j < 0 || j >= width) {
184- continue ;
185- }
186- grid.at <int >(i, j) = EMPTY_OR_SUPPRESSED;
187- }
188- }
189-
190- grid.at <int >(y, x) = KEPT;
191- keepIndices.emplace_back (idx);
192- }
193- }
194-
195- return keepIndices;
196- }
197-
19878KeyPointAndDesc processOneFrame (const Ort::SuperPoint& osh, const cv::Mat& inputImg, float * dst, int borderRemove,
19979 float confidenceThresh, bool alignCorners, int distThresh)
20080{
@@ -204,22 +84,22 @@ KeyPointAndDesc processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& input
20484 osh.preprocess (dst, scaledImg.data , Ort::SuperPoint::IMG_W, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_CHANNEL);
20585 auto inferenceOutput = osh ({dst});
20686
207- std::vector<cv::KeyPoint> keyPoints = getKeyPoints (inferenceOutput, borderRemove, confidenceThresh);
87+ std::vector<cv::KeyPoint> keyPoints = osh. getKeyPoints (inferenceOutput, borderRemove, confidenceThresh);
20888
20989 std::vector<int > descriptorShape (inferenceOutput[1 ].second .begin (), inferenceOutput[1 ].second .end ());
21090 cv::Mat coarseDescriptorMat (descriptorShape.size (), descriptorShape.data (), CV_32F,
21191 inferenceOutput[1 ].first ); // 1 x 256 x H/8 x W/8
21292
213- std::vector<int > keepIndices = nmsFast (keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, distThresh);
93+ std::vector<int > keepIndices = osh. nmsFast (keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, distThresh);
21494
21595 std::vector<cv::KeyPoint> keepKeyPoints;
21696 keepKeyPoints.reserve (keepIndices.size ());
21797 std::transform (keepIndices.begin (), keepIndices.end (), std::back_inserter (keepKeyPoints),
21898 [&keyPoints](int idx) { return keyPoints[idx]; });
21999 keyPoints = std::move (keepKeyPoints);
220100
221- cv::Mat descriptors =
222- getDescriptors (coarseDescriptorMat, keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, alignCorners);
101+ cv::Mat descriptors = osh. getDescriptors (coarseDescriptorMat, keyPoints, Ort::SuperPoint::IMG_H,
102+ Ort::SuperPoint::IMG_W, alignCorners);
223103
224104 for (auto & keyPoint : keyPoints) {
225105 keyPoint.pt .x *= static_cast <float >(origW) / Ort::SuperPoint::IMG_W;
0 commit comments