Skip to content

Commit e592b72

Browse files
committed
[REFACTOR] VerticalDetector inherits from Detector
- Changed error messages in Detector classes - Added input width check in generate functions - Reverted 'export' keyword from modelUrls.ts
1 parent 6b0dd6a commit e592b72

File tree

15 files changed

+160
-89
lines changed

15 files changed

+160
-89
lines changed

packages/react-native-executorch/common/rnexecutorch/models/ocr/Constants.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#pragma once
22

3+
#include <array>
34
#include <cstdint>
45
#include <opencv2/opencv.hpp>
6+
#include <vector>
57

68
namespace rnexecutorch::models::ocr::constants {
79

@@ -30,6 +32,11 @@ inline constexpr int32_t kVerticalLineThreshold = 20;
3032
inline constexpr int32_t kSmallDetectorWidth = 320;
3133
inline constexpr int32_t kMediumDetectorWidth = 800;
3234
inline constexpr int32_t kLargeDetectorWidth = 1280;
35+
// Input widths accepted by the detector; each one is used to build a
// `forward_<width>` method name and to validate the width passed to generate().
inline constexpr std::array<int32_t, 3> kDetectorInputWidths = {
36+
kSmallDetectorWidth, kMediumDetectorWidth, kLargeDetectorWidth};
37+
// Input widths accepted by the recognizer; used to validate the width passed
// to Recognizer::generate().
inline constexpr std::array<int32_t, 4> kRecognizerInputWidths = {
38+
kSmallVerticalRecognizerWidth, kSmallRecognizerWidth,
39+
kMediumRecognizerWidth, kLargeRecognizerWidth};
3340

3441
/*
3542
Mean and variance values for image normalization were used in EASYOCR pipeline

packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.cpp

Lines changed: 69 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,34 @@
11
#include "Detector.h"
2+
#include "Constants.h"
3+
#include <cstdint>
24
#include <rnexecutorch/data_processing/ImageProcessing.h>
35
#include <rnexecutorch/models/ocr/Constants.h>
46
#include <rnexecutorch/models/ocr/utils/DetectorUtils.h>
7+
#include <stdexcept>
58
#include <string>
6-
79
namespace rnexecutorch::models::ocr {
810
/**
 * Constructs the detector and caches the image size expected by every
 * supported `forward_<width>` method.
 *
 * For each width in constants::kDetectorInputWidths the input shapes of the
 * corresponding method are checked up front, so that a model with a missing
 * or malformed input fails at construction instead of during inference.
 *
 * @param modelSource Model source string, forwarded to BaseModel.
 * @param callInvoker Call invoker, forwarded to BaseModel.
 * @throws std::runtime_error If a method reports no input shape, or its first
 *         input shape has fewer than 2 dimensions.
 */
Detector::Detector(const std::string &modelSource,
                   std::shared_ptr<react::CallInvoker> callInvoker)
    : BaseModel(modelSource, callInvoker) {

  for (const auto inputWidth : constants::kDetectorInputWidths) {
    const std::string methodName = "forward_" + std::to_string(inputWidth);
    const auto inputShapes = getAllInputShapes(methodName);

    // Guard before indexing: inputShapes[0] on an empty result is undefined
    // behaviour.
    if (inputShapes.empty()) {
      throw std::runtime_error(
          "Detector model has no input shape for method: " + methodName);
    }

    if (inputShapes[0].size() < 2) {
      throw std::runtime_error(
          "Unexpected detector model input size for method:" + methodName +
          ", expected "
          "at least 2 dimensions but got: " +
          std::to_string(inputShapes[0].size()) + ".");
    }
  }

  // Cache the per-width image sizes so generate() does not have to recompute
  // them on every call.
  this->modelSmallImageSize =
      calculateImageSizeForWidth(constants::kSmallDetectorWidth);
  this->modelMediumImageSize =
      calculateImageSizeForWidth(constants::kMediumDetectorWidth);
  this->modelLargeImageSize =
      calculateImageSizeForWidth(constants::kLargeDetectorWidth);
}
1132

1233
std::vector<types::DetectorBBox> Detector::generate(const cv::Mat &inputImage,
1334
int32_t inputWidth) {
@@ -19,40 +40,46 @@ std::vector<types::DetectorBBox> Detector::generate(const cv::Mat &inputImage,
1940
original aspect ratio and the missing parts are filled with padding.
2041
*/
2142

22-
std::string methodName = "forward_" + std::to_string(inputWidth);
43+
utils::validateInputWidth(inputWidth, constants::kDetectorInputWidths,
44+
"Detector");
2345

46+
std::string methodName = "forward_" + std::to_string(inputWidth);
2447
auto inputShapes = getAllInputShapes(methodName);
25-
if (inputShapes.empty()) {
26-
throw std::runtime_error("Detector model: invalid method name " +
27-
methodName);
28-
}
29-
30-
std::vector<int32_t> modelInputShape = inputShapes[0];
31-
32-
if (modelInputShape.size() < 2) {
33-
throw std::runtime_error("Detector model: invalid method name: " +
34-
methodName);
35-
}
3648

37-
cv::Size modelInputSize =
38-
cv::Size(modelInputShape[modelInputShape.size() - 1],
39-
modelInputShape[modelInputShape.size() - 2]);
49+
cv::Size modelInputSize = getModelImageSize(inputWidth);
4050

4151
cv::Mat resizedInputImage =
4252
image_processing::resizePadded(inputImage, modelInputSize);
4353
TensorPtr inputTensor = image_processing::getTensorFromMatrix(
4454
inputShapes[0], resizedInputImage, constants::kNormalizationMean,
4555
constants::kNormalizationVariance);
4656
auto forwardResult = BaseModel::execute(methodName, {inputTensor});
57+
4758
if (!forwardResult.ok()) {
4859
throw std::runtime_error(
49-
"Failed to forward, error: " +
60+
"Failed to " + methodName + " error: " +
5061
std::to_string(static_cast<uint32_t>(forwardResult.error())));
5162
}
5263

5364
return postprocess(forwardResult->at(0).toTensor(), modelInputSize);
5465
}
5566

67+
/**
 * Returns the cached model image size for the given detector input width.
 *
 * @param inputWidth One of the detector width constants
 *        (kSmallDetectorWidth, kMediumDetectorWidth, kLargeDetectorWidth).
 * @return The size cached for that width. Any other width yields the medium
 *         size: the function is noexcept, so an unsupported width cannot be
 *         reported by throwing here.
 */
cv::Size Detector::getModelImageSize(int inputWidth) const noexcept {
  switch (inputWidth) {
  case constants::kSmallDetectorWidth:
    return modelSmallImageSize;
  case constants::kLargeDetectorWidth:
    return modelLargeImageSize;
  case constants::kMediumDetectorWidth:
  default:
    // Medium doubles as the fallback for unrecognised widths.
    return modelMediumImageSize;
  }
}
82+
5683
std::vector<types::DetectorBBox>
5784
Detector::postprocess(const Tensor &tensor,
5885
const cv::Size &modelInputSize) const {
@@ -103,4 +130,28 @@ Detector::postprocess(const Tensor &tensor,
103130
return bBoxesList;
104131
}
105132

133+
/**
 * Derives the image size expected by the `forward_<width>` method from the
 * shape of that method's first input.
 *
 * @param methoInputWidth Width used to build the method name
 *        (`forward_<methoInputWidth>`).
 * @return cv::Size built from the last dimension (passed as width) and the
 *         second-to-last dimension (passed as height) of the first input.
 * @throws std::runtime_error If the method reports no input shapes, or the
 *         first input shape has fewer than 2 dimensions.
 */
cv::Size Detector::calculateImageSizeForWidth(const int methoInputWidth) {
  const std::string forwardMethod =
      "forward_" + std::to_string(methoInputWidth);
  auto shapes = getAllInputShapes(forwardMethod);

  if (shapes.empty()) {
    throw std::runtime_error("Detector model has no input shape for method: " +
                             forwardMethod);
  }

  std::vector<int32_t> firstInputShape = shapes[0];
  const auto rank = firstInputShape.size();
  if (rank < 2) {
    throw std::runtime_error("Unexpected detector model input size, expected "
                             "at least 2 dimensions but got: " +
                             std::to_string(rank) + ".");
  }

  // Only the two trailing dimensions matter for the image size.
  const auto lastDim = firstInputShape[rank - 1];
  const auto secondToLastDim = firstInputShape[rank - 2];
  return cv::Size(lastDim, secondToLastDim);
}
156+
106157
} // namespace rnexecutorch::models::ocr

packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.h

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,25 @@ namespace rnexecutorch::models::ocr {
1717
using executorch::aten::Tensor;
1818
using executorch::extension::TensorPtr;
1919

20-
class Detector final : public BaseModel {
20+
// Text detector built on top of BaseModel. The class is non-final and
// generate() is virtual so that other detectors can derive from it (the
// commit title states that VerticalDetector inherits from Detector).
class Detector : public BaseModel {
2121
public:
2222
// Constructs the detector from the given model source and call invoker.
explicit Detector(const std::string &modelSource,
2323
std::shared_ptr<react::CallInvoker> callInvoker);
24-
std::vector<types::DetectorBBox> generate(const cv::Mat &inputImage,
25-
int32_t inputWidth);
24+
// Runs detection on inputImage using the `forward_<inputWidth>` method and
// returns the detected bounding boxes.
virtual std::vector<types::DetectorBBox> generate(const cv::Mat &inputImage,
25+
int32_t inputWidth);
26+
27+
// Returns the cached model image size for inputWidth; widths other than the
// small/medium/large detector constants yield the medium size.
cv::Size getModelImageSize(int inputWidth) const noexcept;
28+
29+
protected:
30+
// Reads the first input shape of `forward_<methoInputWidth>` and builds a
// cv::Size from its last two dimensions; throws std::runtime_error when the
// shape is missing or has fewer than 2 dimensions.
cv::Size calculateImageSizeForWidth(const int methoInputWidth);
31+
32+
// Image sizes for the three supported widths, assigned in the constructor.
cv::Size modelSmallImageSize;
33+
cv::Size modelMediumImageSize;
34+
cv::Size modelLargeImageSize;
35+
36+
// NOTE(review): no definition of runInference is visible in this diff;
// presumably a shared inference helper for derived detectors -- confirm.
TensorPtr runInference(const cv::Mat &inputImage, int32_t inputWidth,
37+
const std::string &detectorName);
2638

27-
private:
2839
// Turns the model output tensor into bounding boxes; generate() passes the
// same modelInputSize that it used to resize the input image.
std::vector<types::DetectorBBox>
2840
postprocess(const Tensor &tensor, const cv::Size &modelInputSize) const;
3041
};

packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
namespace rnexecutorch::models::ocr {
77
// Builds the OCR pipeline: the detector is initialised from detectorSource
// and the recognition handler from recognizerSource and symbols; callInvoker
// is passed to both members.
OCR::OCR(const std::string &detectorSource, const std::string &recognizerSource,
8-
std::string symbols, std::shared_ptr<react::CallInvoker> callInvoker)
8+
const std::string &symbols,
9+
std::shared_ptr<react::CallInvoker> callInvoker)
910
: detector(detectorSource, callInvoker),
1011
recognitionHandler(recognizerSource, symbols, callInvoker) {}
1112

packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace models::ocr {
2525
class OCR final {
2626
public:
2727
explicit OCR(const std::string &detectorSource,
28-
const std::string &recognizerSource, std::string symbols,
28+
const std::string &recognizerSource, const std::string &symbols,
2929
std::shared_ptr<react::CallInvoker> callInvoker);
3030
std::vector<types::OCRDetection> generate(std::string input);
3131
std::size_t getMemoryLowerBound() const noexcept;

packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace rnexecutorch::models::ocr {
77
RecognitionHandler::RecognitionHandler(
8-
const std::string &recognizerSource, std::string symbols,
8+
const std::string &recognizerSource, const std::string &symbols,
99
std::shared_ptr<react::CallInvoker> callInvoker)
1010
: converter(symbols), recognizer(recognizerSource, callInvoker) {
1111
memorySizeLowerBound = recognizer.getMemoryLowerBound();

packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace rnexecutorch::models::ocr {
1818
class RecognitionHandler final {
1919
public:
2020
explicit RecognitionHandler(const std::string &recognizer,
21-
std::string symbols,
21+
const std::string &symbols,
2222
std::shared_ptr<react::CallInvoker> callInvoker);
2323
std::vector<types::OCRDetection>
2424
recognize(std::vector<types::DetectorBBox> bboxesList, cv::Mat &imgGray,

packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#include "Recognizer.h"
2+
#include "Constants.h"
23
#include <numeric>
34
#include <rnexecutorch/data_processing/ImageProcessing.h>
45
#include <rnexecutorch/data_processing/Numerical.h>
56
#include <rnexecutorch/models/ocr/Constants.h>
67
#include <rnexecutorch/models/ocr/Types.h>
8+
#include <rnexecutorch/models/ocr/utils/DetectorUtils.h>
79
#include <rnexecutorch/models/ocr/utils/RecognizerUtils.h>
810
#include <string>
911

@@ -23,6 +25,9 @@ Recognizer::generate(const cv::Mat &grayImage, int32_t inputWidth) {
2325
The `generate` function as an argument accepts an image in grayscale
2426
already resized to the expected size.
2527
*/
28+
utils::validateInputWidth(inputWidth, constants::kRecognizerInputWidths,
29+
"Recognizer");
30+
2631
std::string method_name = "forward_" + std::to_string(inputWidth);
2732
auto shapes = getAllInputShapes(method_name);
2833
if (shapes.empty()) {

packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/DetectorUtils.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,4 +707,21 @@ groupTextBoxes(std::vector<types::DetectorBBox> &boxes, float centerThreshold,
707707
return orderedSortedBoxes;
708708
}
709709

710+
void validateInputWidth(int32_t inputWidth, std::span<const int32_t> constants,
711+
std::string modelName) {
712+
auto it = std::ranges::find(constants, inputWidth);
713+
714+
if (it == constants.end()) {
715+
std::string allowed;
716+
for (size_t i = 0; i < constants.size(); ++i) {
717+
allowed +=
718+
std::to_string(constants[i]) + (i < constants.size() - 1 ? ", " : "");
719+
}
720+
721+
throw std::runtime_error("Unexpected input width for " + modelName +
722+
"! Expected [" + allowed + "] but got " +
723+
std::to_string(inputWidth) + ".");
724+
}
725+
}
726+
710727
} // namespace rnexecutorch::models::ocr::utils

packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/DetectorUtils.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,19 @@ groupTextBoxes(std::vector<types::DetectorBBox> &boxes, float centerThreshold,
7878
float distanceThreshold, float heightThreshold,
7979
int32_t minSideThreshold, int32_t maxSideThreshold,
8080
int32_t maxWidth);
81+
82+
/**
83+
* Validates that the provided input width is supported by the model.
84+
* This function checks the input width against the allowed widths passed
85+
* in the `constants` span. If the width is not found, it throws an error
86+
* whose message lists all valid options.
87+
*
88+
* @param inputWidth The input width to be validated.
89+
* @param constants Span of allowed input widths.
90+
* @param modelName Model name used when generating the error message.
91+
* @throws std::runtime_error If inputWidth is not present in the allowed
92+
* input widths.
93+
*/
94+
void validateInputWidth(int32_t inputWidth, std::span<const int32_t> constants,
95+
std::string modelName);
8196
} // namespace rnexecutorch::models::ocr::utils

0 commit comments

Comments
 (0)