Binary file modified packages/react-native-executorch/android/libs/classes.jar
@@ -5,6 +5,7 @@
 #include <rnexecutorch/threads/GlobalThreadPool.h>

 namespace rnexecutorch::models::llm {
+namespace llm = ::executorch::extension::llm;
 namespace fs = std::filesystem;
 using namespace facebook;
 using executorch::extension::TensorPtr;
@@ -14,8 +15,8 @@ using executorch::runtime::Error;
 LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource,
          std::shared_ptr<react::CallInvoker> callInvoker)
     : BaseModel(modelSource, callInvoker, Module::LoadMode::File),
-      runner(std::make_unique<example::Runner>(module_.get(), tokenizerSource,
-                                               false)) {
+      runner(
+          std::make_unique<example::Runner>(module_.get(), tokenizerSource)) {
   auto loadResult = runner->load();
   if (loadResult != Error::Ok) {
     throw std::runtime_error("Failed to load LLM runner, error code: " +
@@ -24,20 +25,9 @@ LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource,

   memorySizeLowerBound = fs::file_size(fs::path(modelSource)) +
                          fs::file_size(fs::path(tokenizerSource));
-
-  // Determine the input mode
-  auto inputShapes = getAllInputShapes("forward");
-  auto &tokensTensorShape = inputShapes[0];
-  auto &positionsTensorShape = inputShapes[1];
-  if (tokensTensorShape.size() != 2 || positionsTensorShape.size() != 1) {
-    throw std::runtime_error("Unsupported LLM input format");
-  }
-  if (positionsTensorShape[0] != 1 &&
-      tokensTensorShape[1] == positionsTensorShape[0]) {
-    runner->set_extended_input_mode(true);
-  }
 }

+// TODO: add a way to manipulate the generation config with params
 void LLM::generate(std::string input, std::shared_ptr<jsi::Function> callback) {
   if (!runner || !runner->is_loaded()) {
     throw std::runtime_error("Runner is not loaded");
@@ -50,7 +40,8 @@ void LLM::generate(std::string input, std::shared_ptr<jsi::Function> callback) {
     });
   };

-  auto error = runner->generate(input, nativeCallback, {}, false);
+  auto config = llm::GenerationConfig{.echo = false, .warming = false};
+  auto error = runner->generate(input, config, nativeCallback, {});
   if (error != executorch::runtime::Error::Ok) {
     throw std::runtime_error("Failed to generate text, error code: " +
                              std::to_string(static_cast<int>(error)));
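
For reference, the updated call above uses the ExecuTorch 1.0 generate() signature, which takes an explicit llm::GenerationConfig. Below is a minimal sketch of one way the TODO (wiring caller-supplied params into the config) could look; it is hypothetical and not part of this PR, and the temperature / max_new_tokens field names are assumptions to verify against the ExecuTorch GenerationConfig header:

#include <cstdint>
#include <optional>

// Hypothetical caller-facing params; only .echo and .warming below are
// confirmed by the diff above.
struct GenerationParams {
  std::optional<float> temperature;
  std::optional<int32_t> maxNewTokens;
};

llm::GenerationConfig makeConfig(const GenerationParams &params) {
  auto config = llm::GenerationConfig{.echo = false, .warming = false};
  if (params.temperature) {
    config.temperature = *params.temperature; // assumed field name
  }
  if (params.maxNewTokens) {
    config.max_new_tokens = *params.maxNewTokens; // assumed field name
  }
  return config;
}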
@@ -30,12 +30,6 @@ class LLM : public BaseModel {

 private:
   std::unique_ptr<example::Runner> runner;
-
-  // A typical input for parallel processing in an exported LLM model consists
-  // of 2 tensors of shapes [1, N] and [1], where N is the number of tokens.
-  // However, some exported models require inputs of shapes [1, N] and [N],
-  // which needs to be marked before using the LLM runner.
-  bool extended_input_mode_ = false;
Comment on lines -33 to -38

Member: Why is this removed? Is it not needed anymore for Gemma?

Contributor (Author): ExecuTorch 1.0 introduces its own implementation of this mechanism with the new TextDecoderRunner and specifically the populate_start_pos_or_cache_position() function, so our extra code is no longer needed.
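
For context, a rough sketch of the mechanism the removed flag covered (illustrative only, not ExecuTorch's actual populate_start_pos_or_cache_position() implementation): depending on how the model was exported, the positions input is either a single start position of shape [1] or a full cache-position vector of shape [N].

#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper: build the positions input for the next decode step.
std::vector<int64_t> makePositions(size_t expectedNumel, int64_t startPos,
                                   size_t numTokens) {
  if (expectedNumel == 1) {
    return {startPos}; // model expects a scalar start position, shape [1]
  }
  // Model expects one position per token, shape [N]:
  // startPos, startPos + 1, ..., startPos + numTokens - 1.
  std::vector<int64_t> positions(numTokens);
  std::iota(positions.begin(), positions.end(), startPos);
  return positions;
}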

 };
 } // namespace models::llm

@@ -5,7 +5,7 @@ size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor,
                                      size_t numClass, size_t size,
                                      std::vector<float> &resultVector,
                                      size_t startIdx) {
-  const auto* rawData = tensor.const_data_ptr<float>();
+  const auto *rawData = tensor.const_data_ptr<float>();
   for (size_t i = 0; i < size; i++) {
     resultVector[startIdx + i] = rawData[numClass * i];
   }
44 changes: 44 additions & 0 deletions packages/react-native-executorch/common/runner/arange_util.cpp
@@ -0,0 +1,44 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "arange_util.h"

namespace torch::executor::native {
#define ET_ARANGE_IMPL(ctx, start, numel, step, out, op_name) \
ET_SWITCH_REALHBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE, [&]() { \
auto out_data = out.mutable_data_ptr<CTYPE>(); \
for (executorch::aten::SizesType i = 0; i < numel; ++i) { \
out_data[i] = static_cast<CTYPE>(start + i * step); \
} \
})

executorch::aten::SizesType compute_arange_out_size(double start, double end,
double step) {
executorch::aten::SizesType numel =
static_cast<executorch::aten::SizesType>(std::ceil((end - start) / step));

ET_CHECK_MSG(numel >= 0,
"numel should be non-negative, but got (%" PRId64
"). start (%f), end (%f), step (%f)",
static_cast<int64_t>(numel), start, end, step);
return numel;
}

void arange_out_impl(KernelRuntimeContext &ctx, double start, double end,
double step, Tensor &out) {
(void)ctx;
executorch::aten::SizesType numel = compute_arange_out_size(start, end, step);
ET_ARANGE_IMPL(ctx, start, numel, step, out, "arange.start_out");
}

void arange_out_impl(KernelRuntimeContext &ctx, double end, Tensor &out) {
(void)ctx;
ET_ARANGE_IMPL(ctx, 0.0, end, 1.0, out, "arange.out");
}

} // namespace torch::executor::native
37 changes: 37 additions & 0 deletions packages/react-native-executorch/common/runner/arange_util.h
@@ -0,0 +1,37 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include "kernel_includes.h"

namespace torch::executor::native {

executorch::aten::SizesType compute_arange_out_size(double start, double end,
double step);

inline executorch::aten::SizesType compute_arange_out_size(double end) {
return compute_arange_out_size(0.0, end, 1.0);
}

void arange_out_impl(KernelRuntimeContext &ctx, double start, double end,
double step, Tensor &out);

void arange_out_impl(KernelRuntimeContext &ctx, double end, Tensor &out);

inline void arange_out_impl(double start, double end, double step,
Tensor &out) {
KernelRuntimeContext ctx;
arange_out_impl(ctx, start, end, step, out);
}

inline void arange_out_impl(double end, Tensor &out) {
KernelRuntimeContext ctx;
arange_out_impl(ctx, 0.0, end, 1.0, out);
}
} // namespace torch::executor::native
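
A quick usage sketch for the helpers above (hypothetical call site, for illustration): the output is sized before it is filled, since arange(start, end, step) produces ceil((end - start) / step) elements.

#include "arange_util.h"

// compute_arange_out_size(0.0, 10.0, 3.0) == ceil(10 / 3) == 4, so a
// correctly sized output tensor would receive the values 0, 3, 6, 9.
auto numel = torch::executor::native::compute_arange_out_size(0.0, 10.0, 3.0);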
28 changes: 28 additions & 0 deletions packages/react-native-executorch/common/runner/constants.h
@@ -0,0 +1,28 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
// constants for LLM runtime
namespace executorch::extension::llm {

// Runtime metadata key constants
inline constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
inline constexpr auto kBosId = "get_bos_id";
inline constexpr auto kEosIds = "get_eos_ids";
inline constexpr auto kMaxSeqLen = "get_max_seq_len";
inline constexpr auto kMaxContextLen = "get_max_context_len";
inline constexpr auto kVocabSize = "get_vocab_size";
inline constexpr auto kUseKVCache = "use_kv_cache";
inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";

// Multimodal method name conventions
inline constexpr auto kVisionEncoderMethod = "vision_encoder";
inline constexpr auto kAudioEncoderMethod = "audio_encoder";
inline constexpr auto kTokenEmbeddingMethod = "token_embedding";
inline constexpr auto kTextModelMethod = "text_decoder";

} // namespace executorch::extension::llm
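
A rough sketch of how these metadata keys are typically consumed (an assumption for illustration, not code from this PR): exported .pte files expose small methods such as "get_max_seq_len" that return a single value, which the runner reads once at load time.

#include <cstdint>
#include <executorch/extension/module/module.h>

// Hypothetical helper: read kMaxSeqLen from the model, with a fallback when
// the metadata method is absent or malformed.
int64_t readMaxSeqLen(executorch::extension::Module &module, int64_t fallback) {
  auto result = module.execute(executorch::extension::llm::kMaxSeqLen);
  if (result.ok() && !result->empty() && (*result)[0].isInt()) {
    return (*result)[0].toInt();
  }
  return fallback;
}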