Skip to content

Commit cbf7e36

Browse files
committed
[REFACTOR] VerticalDetector inherits from Detector
- Changed error messages in Detector classes - Added input width check in generate functions - Reverted 'export' keyword from modelUrls.ts
1 parent 6b0dd6a commit cbf7e36

File tree

16 files changed

+131
-105
lines changed

16 files changed

+131
-105
lines changed

packages/react-native-executorch/common/rnexecutorch/models/ocr/Constants.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#pragma once
22

3+
#include <array>
34
#include <cstdint>
45
#include <opencv2/opencv.hpp>
6+
#include <vector>
57

68
namespace rnexecutorch::models::ocr::constants {
79

@@ -30,6 +32,11 @@ inline constexpr int32_t kVerticalLineThreshold = 20;
3032
inline constexpr int32_t kSmallDetectorWidth = 320;
3133
inline constexpr int32_t kMediumDetectorWidth = 800;
3234
inline constexpr int32_t kLargeDetectorWidth = 1280;
35+
inline constexpr std::array<int32_t, 3> kDetectorInputWidths = {
36+
kSmallDetectorWidth, kMediumDetectorWidth, kLargeDetectorWidth};
37+
inline constexpr std::array<int32_t, 4> kRecognizerInputWidths = {
38+
kSmallVerticalRecognizerWidth, kSmallRecognizerWidth,
39+
kMediumRecognizerWidth, kLargeRecognizerWidth};
3340

3441
/*
3542
Mean and variance values for image normalization were used in EASYOCR pipeline

packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.cpp

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,28 @@
11
#include "Detector.h"
2+
#include "Constants.h"
3+
#include <cstdint>
24
#include <rnexecutorch/data_processing/ImageProcessing.h>
35
#include <rnexecutorch/models/ocr/Constants.h>
46
#include <rnexecutorch/models/ocr/utils/DetectorUtils.h>
7+
#include <stdexcept>
58
#include <string>
6-
79
namespace rnexecutorch::models::ocr {
810
Detector::Detector(const std::string &modelSource,
911
std::shared_ptr<react::CallInvoker> callInvoker)
10-
: BaseModel(modelSource, callInvoker) {}
12+
: BaseModel(modelSource, callInvoker) {
13+
14+
for (auto input_size : constants::kDetectorInputWidths) {
15+
std::string methodName = "forward_" + std::to_string(input_size);
16+
auto inputShapes = getAllInputShapes(methodName);
17+
if (inputShapes[0].size() < 2) {
18+
throw std::runtime_error(
19+
"Unexpected detector model input size for method:" + methodName +
20+
", expected "
21+
"at least 2 dimensions but got: " +
22+
std::to_string(inputShapes[0].size()) + ".");
23+
}
24+
}
25+
}
1126

1227
std::vector<types::DetectorBBox> Detector::generate(const cv::Mat &inputImage,
1328
int32_t inputWidth) {
@@ -19,43 +34,46 @@ std::vector<types::DetectorBBox> Detector::generate(const cv::Mat &inputImage,
1934
original aspect ratio and the missing parts are filled with padding.
2035
*/
2136

22-
std::string methodName = "forward_" + std::to_string(inputWidth);
37+
utils::validateInputWidth(inputWidth, constants::kDetectorInputWidths,
38+
"Detector");
2339

40+
std::string methodName = "forward_" + std::to_string(inputWidth);
2441
auto inputShapes = getAllInputShapes(methodName);
25-
if (inputShapes.empty()) {
26-
throw std::runtime_error("Detector model: invalid method name " +
27-
methodName);
28-
}
29-
30-
std::vector<int32_t> modelInputShape = inputShapes[0];
3142

32-
if (modelInputShape.size() < 2) {
33-
throw std::runtime_error("Detector model: invalid method name: " +
34-
methodName);
35-
}
36-
37-
cv::Size modelInputSize =
38-
cv::Size(modelInputShape[modelInputShape.size() - 1],
39-
modelInputShape[modelInputShape.size() - 2]);
43+
cv::Size modelInputSize = calculateModelImageSize(inputWidth);
4044

4145
cv::Mat resizedInputImage =
4246
image_processing::resizePadded(inputImage, modelInputSize);
4347
TensorPtr inputTensor = image_processing::getTensorFromMatrix(
4448
inputShapes[0], resizedInputImage, constants::kNormalizationMean,
4549
constants::kNormalizationVariance);
4650
auto forwardResult = BaseModel::execute(methodName, {inputTensor});
51+
4752
if (!forwardResult.ok()) {
4853
throw std::runtime_error(
49-
"Failed to forward, error: " +
54+
"Failed to " + methodName + " error: " +
5055
std::to_string(static_cast<uint32_t>(forwardResult.error())));
5156
}
5257

5358
return postprocess(forwardResult->at(0).toTensor(), modelInputSize);
5459
}
5560

61+
cv::Size Detector::calculateModelImageSize(int32_t methodInputWidth) {
62+
63+
utils::validateInputWidth(methodInputWidth, constants::kDetectorInputWidths,
64+
"Detector");
65+
std::string methodName = "forward_" + std::to_string(methodInputWidth);
66+
67+
auto inputShapes = getAllInputShapes(methodName);
68+
std::vector<int32_t> modelInputShape = inputShapes[0];
69+
cv::Size modelInputSize =
70+
cv::Size(modelInputShape[modelInputShape.size() - 1],
71+
modelInputShape[modelInputShape.size() - 2]);
72+
return modelInputSize;
73+
}
74+
5675
std::vector<types::DetectorBBox>
57-
Detector::postprocess(const Tensor &tensor,
58-
const cv::Size &modelInputSize) const {
76+
Detector::postprocess(const Tensor &tensor, const cv::Size &modelInputSize) {
5977
/*
6078
The output of the model consists of two matrices (heat maps):
6179
1. ScoreText(Score map) - The probability of a region containing character.

packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,20 @@ namespace rnexecutorch::models::ocr {
1717
using executorch::aten::Tensor;
1818
using executorch::extension::TensorPtr;
1919

20-
class Detector final : public BaseModel {
20+
class Detector : public BaseModel {
2121
public:
2222
explicit Detector(const std::string &modelSource,
2323
std::shared_ptr<react::CallInvoker> callInvoker);
24-
std::vector<types::DetectorBBox> generate(const cv::Mat &inputImage,
25-
int32_t inputWidth);
24+
virtual std::vector<types::DetectorBBox> generate(const cv::Mat &inputImage,
25+
int32_t inputWidth);
2626

27-
private:
28-
std::vector<types::DetectorBBox>
29-
postprocess(const Tensor &tensor, const cv::Size &modelInputSize) const;
27+
cv::Size calculateModelImageSize(int32_t methodInputWidth);
28+
29+
protected:
30+
TensorPtr runInference(const cv::Mat &inputImage, int32_t inputWidth,
31+
const std::string &detectorName);
32+
33+
std::vector<types::DetectorBBox> postprocess(const Tensor &tensor,
34+
const cv::Size &modelInputSize);
3035
};
3136
} // namespace rnexecutorch::models::ocr

packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
namespace rnexecutorch::models::ocr {
77
OCR::OCR(const std::string &detectorSource, const std::string &recognizerSource,
8-
std::string symbols, std::shared_ptr<react::CallInvoker> callInvoker)
8+
const std::string &symbols,
9+
std::shared_ptr<react::CallInvoker> callInvoker)
910
: detector(detectorSource, callInvoker),
1011
recognitionHandler(recognizerSource, symbols, callInvoker) {}
1112

packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace models::ocr {
2525
class OCR final {
2626
public:
2727
explicit OCR(const std::string &detectorSource,
28-
const std::string &recognizerSource, std::string symbols,
28+
const std::string &recognizerSource, const std::string &symbols,
2929
std::shared_ptr<react::CallInvoker> callInvoker);
3030
std::vector<types::OCRDetection> generate(std::string input);
3131
std::size_t getMemoryLowerBound() const noexcept;

packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace rnexecutorch::models::ocr {
77
RecognitionHandler::RecognitionHandler(
8-
const std::string &recognizerSource, std::string symbols,
8+
const std::string &recognizerSource, const std::string &symbols,
99
std::shared_ptr<react::CallInvoker> callInvoker)
1010
: converter(symbols), recognizer(recognizerSource, callInvoker) {
1111
memorySizeLowerBound = recognizer.getMemoryLowerBound();

packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace rnexecutorch::models::ocr {
1818
class RecognitionHandler final {
1919
public:
2020
explicit RecognitionHandler(const std::string &recognizer,
21-
std::string symbols,
21+
const std::string &symbols,
2222
std::shared_ptr<react::CallInvoker> callInvoker);
2323
std::vector<types::OCRDetection>
2424
recognize(std::vector<types::DetectorBBox> bboxesList, cv::Mat &imgGray,

packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#include "Recognizer.h"
2+
#include "Constants.h"
23
#include <numeric>
34
#include <rnexecutorch/data_processing/ImageProcessing.h>
45
#include <rnexecutorch/data_processing/Numerical.h>
56
#include <rnexecutorch/models/ocr/Constants.h>
67
#include <rnexecutorch/models/ocr/Types.h>
8+
#include <rnexecutorch/models/ocr/utils/DetectorUtils.h>
79
#include <rnexecutorch/models/ocr/utils/RecognizerUtils.h>
810
#include <string>
911

@@ -23,11 +25,14 @@ Recognizer::generate(const cv::Mat &grayImage, int32_t inputWidth) {
2325
The `generate` function as an argument accepts an image in grayscale
2426
already resized to the expected size.
2527
*/
28+
utils::validateInputWidth(inputWidth, constants::kRecognizerInputWidths,
29+
"Recognizer");
30+
2631
std::string method_name = "forward_" + std::to_string(inputWidth);
2732
auto shapes = getAllInputShapes(method_name);
2833
if (shapes.empty()) {
29-
throw std::runtime_error("Recognizer model: invalid method name " +
30-
method_name);
34+
throw std::runtime_error("Recognizer model: Input shapes for " +
35+
method_name " not found");
3136
}
3237
std::vector<int32_t> tensorDims = shapes[0];
3338
TensorPtr inputTensor =

packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/DetectorUtils.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,4 +707,21 @@ groupTextBoxes(std::vector<types::DetectorBBox> &boxes, float centerThreshold,
707707
return orderedSortedBoxes;
708708
}
709709

710+
void validateInputWidth(int32_t inputWidth, std::span<const int32_t> constants,
711+
std::string modelName) {
712+
auto it = std::ranges::find(constants, inputWidth);
713+
714+
if (it == constants.end()) {
715+
std::string allowed;
716+
for (size_t i = 0; i < constants.size(); ++i) {
717+
allowed +=
718+
std::to_string(constants[i]) + (i < constants.size() - 1 ? ", " : "");
719+
}
720+
721+
throw std::runtime_error("Unexpected input width for " + modelName +
722+
"! Expected [" + allowed + "] but got " +
723+
std::to_string(inputWidth) + ".");
724+
}
725+
}
726+
710727
} // namespace rnexecutorch::models::ocr::utils

packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/DetectorUtils.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,19 @@ groupTextBoxes(std::vector<types::DetectorBBox> &boxes, float centerThreshold,
7878
float distanceThreshold, float heightThreshold,
7979
int32_t minSideThreshold, int32_t maxSideThreshold,
8080
int32_t maxWidth);
81+
82+
/**
83+
* Validates if the provided image width is supported by the model.
84+
* * This method checks the input width against the passed allowed
85+
* widths in constants vector. If the width is not found, it
86+
* constructs a descriptive error message listing all valid options.
87+
*
88+
* @param inputWidth The width of the input image to be validated.
89+
* @param constants Vector of available input sizes.
90+
* @param modelName String with modelNames used for generating error message
91+
* @throws std::runtime_error If inputWidth is not present in the allowed
92+
* detector input widths array.
93+
*/
94+
void validateInputWidth(int32_t inputWidth, std::span<const int32_t> constants,
95+
std::string modelName);
8196
} // namespace rnexecutorch::models::ocr::utils

0 commit comments

Comments
 (0)