/**
 * @file SuperPointApp.cpp
 *
 * @author btran
 *
 */

#include <algorithm>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <opencv2/opencv.hpp>

#include "SuperPoint.hpp"
#include "Utility.hpp"

namespace
{
std::pair<std::vector<cv::KeyPoint>, cv::Mat> processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& inputImg,
                                                              float* dst, int borderRemove = 4,
                                                              float confidenceThresh = 0.015, bool alignCorners = true);
} // namespace

int main(int argc, char* argv[])
{
    if (argc != 4) {
        std::cerr << "Usage: [apps] [path/to/onnx/super/point] [path/to/image1] [path/to/image2]" << std::endl;
        return EXIT_FAILURE;
    }

    const std::string ONNX_MODEL_PATH = argv[1];
    const std::vector<std::string> IMAGE_PATHS = {argv[2], argv[3]};

    Ort::SuperPoint osh(ONNX_MODEL_PATH, 0,
                        std::vector<std::vector<int64_t>>{
                            {1, Ort::SuperPoint::IMG_CHANNEL, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W}});

    std::vector<cv::Mat> images;
    std::vector<cv::Mat> grays;
    std::transform(IMAGE_PATHS.begin(), IMAGE_PATHS.end(), std::back_inserter(images),
                   [](const auto& imagePath) { return cv::imread(imagePath); });
    for (int i = 0; i < 2; ++i) {
        if (images[i].empty()) {
            throw std::runtime_error("failed to open " + IMAGE_PATHS[i]);
        }
    }
    std::transform(IMAGE_PATHS.begin(), IMAGE_PATHS.end(), std::back_inserter(grays),
                   [](const auto& imagePath) { return cv::imread(imagePath, 0); });

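    // preprocessing buffer: holds the C x H x W float input fed to the network; reused for both images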
    std::vector<float> dst(Ort::SuperPoint::IMG_CHANNEL * Ort::SuperPoint::IMG_H * Ort::SuperPoint::IMG_W);
    std::vector<std::pair<std::vector<cv::KeyPoint>, cv::Mat>> results;
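    // run detection and description on each grayscale image; each result pairs the detected keypoints
    // with a 256 x N descriptor matrix (N = number of keypoints)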
    std::transform(grays.begin(), grays.end(), std::back_inserter(results),
                   [&osh, &dst](const auto& gray) { return processOneFrame(osh, gray, dst.data()); });

    return EXIT_SUCCESS;
}

namespace
{
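// decodes the raw detector output (65 x H/8 x W/8) into keypoints whose score exceeds the confidence threshold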
std::vector<cv::KeyPoint> getKeyPoints(const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
                                       int borderRemove, float confidenceThresh)
{
    std::vector<int> detectorShape(inferenceOutput[0].second.begin() + 1, inferenceOutput[0].second.end());

    cv::Mat detectorMat(detectorShape.size(), detectorShape.data(), CV_32F,
                        inferenceOutput[0].first); // 65 x H/8 x W/8
    cv::Mat buffer;

    transposeNDWrapper(detectorMat, {1, 2, 0}, buffer);
    buffer.copyTo(detectorMat); // H/8 x W/8 x 65

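    // softmax over the 65 detector bins at each spatial location
    // (64 bins for the 8x8 cell positions plus one "dustbin" no-keypoint bin)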
    for (int i = 0; i < detectorShape[1]; ++i) {
        for (int j = 0; j < detectorShape[2]; ++j) {
            Ort::softmax(detectorMat.ptr<float>(i, j), detectorShape[0]);
        }
    }
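    // drop the dustbin bin and unfold each 8x8 cell to recover the full-resolution score map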
    detectorMat = detectorMat({cv::Range::all(), cv::Range::all(), cv::Range(0, detectorShape[0] - 1)})
                      .clone(); // H/8 x W/8 x 64
    detectorMat = detectorMat.reshape(1, {detectorShape[1], detectorShape[2], 8, 8}); // H/8 x W/8 x 8 x 8
    transposeNDWrapper(detectorMat, {0, 2, 1, 3}, buffer);
    buffer.copyTo(detectorMat); // H/8 x 8 x W/8 x 8

    detectorMat = detectorMat.reshape(1, {detectorShape[1] * 8, detectorShape[2] * 8}); // H x W

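    // keep locations whose score exceeds the confidence threshold, ignoring a borderRemove-pixel margin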
    std::vector<cv::KeyPoint> keyPoints;
    for (int i = borderRemove; i < detectorMat.rows - borderRemove; ++i) {
        auto rowPtr = detectorMat.ptr<float>(i);
        for (int j = borderRemove; j < detectorMat.cols - borderRemove; ++j) {
            if (rowPtr[j] > confidenceThresh) {
                cv::KeyPoint keyPoint;
                keyPoint.pt.x = j;
                keyPoint.pt.y = i;
                keyPoint.response = rowPtr[j];
                keyPoints.emplace_back(keyPoint);
            }
        }
    }

    return keyPoints;
}

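// samples a descriptor for each keypoint from the coarse descriptor map via bilinear interpolation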
cv::Mat getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
                       int width, bool alignCorners)
{
    cv::Mat keyPointMat(static_cast<int>(keyPoints.size()), 2, CV_32F);

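    // map keypoint pixel coordinates to the [-1, 1] range used by bilinearGridSample (row coordinate, then column)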
    for (int i = 0; i < static_cast<int>(keyPoints.size()); ++i) {
        auto rowPtr = keyPointMat.ptr<float>(i);
        rowPtr[0] = 2 * keyPoints[i].pt.y / (height - 1) - 1;
        rowPtr[1] = 2 * keyPoints[i].pt.x / (width - 1) - 1;
    }
    keyPointMat = keyPointMat.reshape(1, {1, 1, static_cast<int>(keyPoints.size()), 2});
    cv::Mat descriptors = bilinearGridSample(coarseDescriptors, keyPointMat, alignCorners);

    return descriptors.reshape(1, {coarseDescriptors.size[1], static_cast<int>(keyPoints.size())});
}

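// preprocesses one grayscale frame, runs inference, and returns the detected keypoints (in original image
// coordinates) together with their sampled descriptors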
std::pair<std::vector<cv::KeyPoint>, cv::Mat> processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& inputImg,
                                                              float* dst, int borderRemove, float confidenceThresh,
                                                              bool alignCorners)
{
    int origW = inputImg.cols, origH = inputImg.rows;
    cv::Mat scaledImg;
    cv::resize(inputImg, scaledImg, cv::Size(Ort::SuperPoint::IMG_W, Ort::SuperPoint::IMG_H), 0, 0, cv::INTER_CUBIC);
    osh.preprocess(dst, scaledImg.data, Ort::SuperPoint::IMG_W, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_CHANNEL);
    auto inferenceOutput = osh({dst});

    std::vector<cv::KeyPoint> keyPoints = getKeyPoints(inferenceOutput, borderRemove, confidenceThresh);

    std::vector<int> descriptorShape(inferenceOutput[1].second.begin(), inferenceOutput[1].second.end());
    cv::Mat coarseDescriptorMat(descriptorShape.size(), descriptorShape.data(), CV_32F,
                                inferenceOutput[1].first); // 1 x 256 x H/8 x W/8

    cv::Mat descriptors =
        getDescriptors(coarseDescriptorMat, keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, alignCorners);

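    // rescale keypoint coordinates from the network input resolution back to the original image resolution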
    for (auto& keyPoint : keyPoints) {
        keyPoint.pt.x *= static_cast<float>(origW) / Ort::SuperPoint::IMG_W;
        keyPoint.pt.y *= static_cast<float>(origH) / Ort::SuperPoint::IMG_H;
    }

    return {keyPoints, descriptors};
}
} // namespace