
Commit 1ab849a

chore: apply review suggestions
1 parent: d67b5d9

5 files changed: 16 additions & 22 deletions


packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h

Lines changed: 1 addition & 1 deletion
@@ -55,4 +55,4 @@ class BaseModel {
 
 REGISTER_CONSTRUCTOR(models::BaseModel, std::string,
                      std::shared_ptr<react::CallInvoker>);
-} // namespace rnexecutorch
+} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp

Lines changed: 6 additions & 5 deletions
@@ -6,6 +6,7 @@
 #include <rnexecutorch/threads/GlobalThreadPool.h>
 
 namespace rnexecutorch::models::llm {
+namespace fs = std::filesystem;
 using namespace facebook;
 using executorch::extension::TensorPtr;
 using executorch::extension::module::Module;
@@ -22,13 +23,13 @@ LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource,
         std::to_string(static_cast<int>(loadResult)));
   }
 
-  memorySizeLowerBound =
-      std::filesystem::file_size(std::filesystem::path(modelSource)) +
-      std::filesystem::file_size(std::filesystem::path(tokenizerSource));
+  memorySizeLowerBound = fs::file_size(fs::path(modelSource)) +
+                         fs::file_size(fs::path(tokenizerSource));
 
   // Determine the input mode
-  auto tokensTensorShape = getInputShape("forward", 0);
-  auto positionsTensorShape = getInputShape("forward", 1);
+  auto inputShapes = getAllInputShapes("forward");
+  auto &tokensTensorShape = inputShapes[0];
+  auto &positionsTensorShape = inputShapes[1];
   if (tokensTensorShape.size() != 2 || positionsTensorShape.size() != 1) {
     throw std::runtime_error("Unsupported LLM input format");
   }
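
Two things worth noting in the hunks above: the new namespace alias (namespace fs = std::filesystem;) is the standard C++ idiom for shortening repeated std::filesystem calls, and getAllInputShapes replaces two per-index getInputShape lookups with a single batched query. A minimal standalone sketch of the alias pattern, using only the standard library (the file name below is a placeholder, not a file from this repo):

#include <filesystem>
#include <iostream>

namespace fs = std::filesystem; // same alias the diff introduces

int main() {
  fs::path model = "model.pte"; // hypothetical file
  if (fs::exists(model)) {
    // fs::file_size returns the size in bytes as std::uintmax_t,
    // mirroring the memorySizeLowerBound computation above
    std::cout << fs::file_size(model) << " bytes\n";
  }
  return 0;
}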

packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h

Lines changed: 1 addition & 1 deletion
@@ -39,4 +39,4 @@ class LLM : public BaseModel {
 
 REGISTER_CONSTRUCTOR(models::llm::LLM, std::string, std::string,
                      std::shared_ptr<react::CallInvoker>);
-} // namespace rnexecutorch
+} // namespace rnexecutorch

packages/react-native-executorch/common/runner/runner.cpp

Lines changed: 2 additions & 6 deletions
@@ -47,8 +47,7 @@ static constexpr auto kUseKVCache = "use_kv_cache";
 static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
 } // namespace
 
-Runner::Runner(Module *module, const std::string &model_path,
-               const std::string &tokenizer_path,
+Runner::Runner(Module *module, const std::string &tokenizer_path,
                const bool extended_input_mode, const float temperature,
                std::optional<const std::string> data_path)
     : module_(module), temperature_(temperature),
@@ -58,10 +57,7 @@ Runner::Runner(Module *module, const std::string &model_path,
           {kMaxContextLen, 128},
           {kUseKVCache, true},
           {kUseSDPAWithKVCache, false},
-      }) {
-  ET_LOG(Info, "Creating LLM runner: model_path=%s, tokenizer_path=%s",
-         model_path.c_str(), tokenizer_path.c_str());
-}
+      }) {}
 
 bool Runner::is_loaded() const {
   return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
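
The two runner.cpp hunks above are linked: deleting the ET_LOG call removes the only remaining use of model_path, which is what allows the parameter to be dropped from the constructor. A standalone sketch of the narrowed signature (RunnerSketch is an illustrative stub, not the real Runner, which also takes an executorch Module pointer):

#include <optional>
#include <string>

class RunnerSketch {
public:
  // Mirrors the new parameter list: tokenizer path only, no model path.
  explicit RunnerSketch(
      const std::string &tokenizer_path, bool extended_input_mode = false,
      float temperature = 0.8f,
      std::optional<const std::string> data_path = std::nullopt)
      : tokenizer_path_(tokenizer_path),
        extended_input_mode_(extended_input_mode),
        temperature_(temperature), data_path_(data_path) {}

private:
  std::string tokenizer_path_;
  bool extended_input_mode_;
  float temperature_;
  std::optional<const std::string> data_path_;
};

int main() {
  // Call sites no longer pass a model path just so it can be logged.
  RunnerSketch runner("/tmp/tokenizer.bin"); // hypothetical path
  return 0;
}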

packages/react-native-executorch/common/runner/runner.h

Lines changed: 6 additions & 9 deletions
@@ -29,13 +29,10 @@ namespace example {
 
 class Runner : public executorch::extension::llm::IRunner {
 public:
-  explicit Runner(
-      ::executorch::extension::Module *module,
-      const std::string &model_path, // TODO: consider removing this arg since
-                                     // it is only used for debug purposes
-      const std::string &tokenizer_path, const bool extended_input_mode = false,
-      const float temperature = 0.8f,
-      std::optional<const std::string> data_path = std::nullopt);
+  explicit Runner(::executorch::extension::Module *module,
+                  const std::string &tokenizer_path,
+                  bool extended_input_mode = false, float temperature = 0.8f,
+                  std::optional<const std::string> data_path = std::nullopt);
 
   bool is_loaded() const;
   ::executorch::runtime::Error load();
@@ -46,7 +43,7 @@ class Runner : public executorch::extension::llm::IRunner {
           stats_callback = {},
       bool echo = true, bool warming = false);
   ::executorch::runtime::Error warmup(const std::string &prompt);
-  void set_extended_input_mode(bool extend_position_input);
+  void set_extended_input_mode(bool extend_position_input) noexcept;
   void set_count_interval(size_t count_interval);
   void set_time_interval(size_t time_interval);
   void stop();
@@ -72,4 +69,4 @@ class Runner : public executorch::extension::llm::IRunner {
       text_token_generator_;
 };
 
-} // namespace example
+} // namespace example
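
On the noexcept added to set_extended_input_mode above: a setter that only assigns a plain bool member cannot throw, so noexcept documents that guarantee and lets the compiler enforce it. A minimal standalone illustration of the idiom (Toggle is a made-up class, not from this repo):

#include <iostream>

class Toggle {
public:
  // Assigning a bool cannot throw, so the noexcept promise is safe.
  void set_enabled(bool enabled) noexcept { enabled_ = enabled; }
  bool enabled() const noexcept { return enabled_; }

private:
  bool enabled_ = false;
};

int main() {
  Toggle t;
  t.set_enabled(true);
  std::cout << std::boolalpha << t.enabled() << "\n"; // prints: true
  // The noexcept operator checks the guarantee at compile time.
  static_assert(noexcept(t.set_enabled(false)), "setter must not throw");
  return 0;
}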
