Skip to content

Commit 747f627

Browse files
chmjkbMateusz Kopciński
authored andcommitted
fix: mark some methods const in the BaseModel
1 parent 5106782 commit 747f627

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ BaseModel::BaseModel(const std::string &modelSource,
2929
}
3030

3131
std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
32-
int32_t index) {
32+
int32_t index) const {
3333
if (!module_) {
3434
throw std::runtime_error("Model not loaded: Cannot get input shape");
3535
}
@@ -87,7 +87,7 @@ BaseModel::getAllInputShapes(std::string methodName) const {
8787
/// to JS. It is not meant to be used within C++. If you want to call forward
8888
/// from C++ on a BaseModel, please use BaseModel::forward.
8989
std::vector<JSTensorViewOut>
90-
BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
90+
BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) const {
9191
if (!module_) {
9292
throw std::runtime_error("Model not loaded: Cannot perform forward pass");
9393
}
@@ -135,7 +135,7 @@ BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
135135
}
136136

137137
Result<executorch::runtime::MethodMeta>
138-
BaseModel::getMethodMeta(const std::string &methodName) {
138+
BaseModel::getMethodMeta(const std::string &methodName) const {
139139
if (!module_) {
140140
throw std::runtime_error("Model not loaded: Cannot get method meta!");
141141
}
@@ -160,7 +160,7 @@ BaseModel::forward(const std::vector<EValue> &input_evalues) const {
160160

161161
Result<std::vector<EValue>>
162162
BaseModel::execute(const std::string &methodName,
163-
const std::vector<EValue> &input_value) {
163+
const std::vector<EValue> &input_value) const {
164164
if (!module_) {
165165
throw std::runtime_error("Model not loaded, cannot run execute.");
166166
}
@@ -174,7 +174,7 @@ std::size_t BaseModel::getMemoryLowerBound() const noexcept {
174174
void BaseModel::unload() noexcept { module_.reset(nullptr); }
175175

176176
std::vector<int32_t>
177-
BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) {
177+
BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) const {
178178
auto sizes = tensor.sizes();
179179
return std::vector<int32_t>(sizes.begin(), sizes.end());
180180
}

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,20 @@ class BaseModel {
2121
std::shared_ptr<react::CallInvoker> callInvoker);
2222
std::size_t getMemoryLowerBound() const noexcept;
2323
void unload() noexcept;
24-
std::vector<int32_t> getInputShape(std::string method_name, int32_t index);
24+
std::vector<int32_t> getInputShape(std::string method_name,
25+
int32_t index) const;
2526
std::vector<std::vector<int32_t>>
2627
getAllInputShapes(std::string methodName = "forward") const;
2728
std::vector<JSTensorViewOut>
28-
forwardJS(std::vector<JSTensorViewIn> tensorViewVec);
29+
forwardJS(std::vector<JSTensorViewIn> tensorViewVec) const;
2930
Result<std::vector<EValue>> forward(const EValue &input_value) const;
3031
Result<std::vector<EValue>>
3132
forward(const std::vector<EValue> &input_value) const;
32-
Result<std::vector<EValue>> execute(const std::string &methodName,
33-
const std::vector<EValue> &input_value);
33+
Result<std::vector<EValue>>
34+
execute(const std::string &methodName,
35+
const std::vector<EValue> &input_value) const;
3436
Result<executorch::runtime::MethodMeta>
35-
getMethodMeta(const std::string &methodName);
37+
getMethodMeta(const std::string &methodName) const;
3638

3739
protected:
3840
// If possible, models should not use the JS runtime to keep JSI internals
@@ -44,7 +46,8 @@ class BaseModel {
4446

4547
private:
4648
std::size_t memorySizeLowerBound{0};
47-
std::vector<int32_t> getTensorShape(const executorch::aten::Tensor &tensor);
49+
std::vector<int32_t>
50+
getTensorShape(const executorch::aten::Tensor &tensor) const;
4851
};
4952
} // namespace models
5053

0 commit comments

Comments
 (0)