Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified android/libs/executorch-llama.aar
Binary file not shown.
97 changes: 97 additions & 0 deletions android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package com.swmansion.rnexecutorch

import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.utils.ArrayUtils
import com.swmansion.rnexecutorch.utils.Fetcher
import com.swmansion.rnexecutorch.utils.ProgressResponseBody
import com.swmansion.rnexecutorch.utils.ResourceType
import com.swmansion.rnexecutorch.utils.TensorUtils
import okhttp3.OkHttpClient
import org.pytorch.executorch.Module
import org.pytorch.executorch.Tensor
import java.net.URL

class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(reactContext) {
private lateinit var module: Module
private val client = OkHttpClient()

override fun getName(): String {
return NAME
}

private fun downloadModel(
url: URL, resourceType: ResourceType, callback: (path: String?, error: Exception?) -> Unit
) {
Fetcher.downloadResource(reactApplicationContext,
client,
url,
resourceType,
{ path, error -> callback(path, error) },
object : ProgressResponseBody.ProgressListener {
override fun onProgress(bytesRead: Long, contentLength: Long, done: Boolean) {
}
})
}

override fun loadModule(modelPath: String, promise: Promise) {
try {
downloadModel(
URL(modelPath), ResourceType.MODEL
) { path, error ->
if (error != null) {
promise.reject(error.message!!, "-1")
return@downloadModel
}

module = Module.load(path)
promise.resolve(0)
return@downloadModel
}
} catch (e: Exception) {
promise.reject(e.message!!, "-1")
}
}

override fun loadMethod(methodName: String, promise: Promise) {
val result = module.loadMethod(methodName)
if (result != 0) {
promise.reject("Method loading failed", result.toString())
return
}

promise.resolve(result)
}

override fun forward(
input: ReadableArray,
shape: ReadableArray,
inputType: Double,
promise: Promise
) {
try {
val executorchInput =
TensorUtils.getExecutorchInput(input, ArrayUtils.createLongArray(shape), inputType.toInt())

lateinit var result: Tensor
module.forward(executorchInput)[0].toTensor().also { result = it }

promise.resolve(ArrayUtils.createReadableArray(result))
return
} catch (e: IllegalArgumentException) {
//The error is thrown when transformation to Tensor fails
promise.reject("Forward Failed Execution", "18")
return
} catch (e: Exception) {
//Executorch forward method throws an exception with a message: "Method forward failed with code XX"
val exceptionCode = e.message!!.substring(e.message!!.length - 2)
promise.reject("Forward Failed Execution", exceptionCode)
return
}
}

companion object {
const val NAME = "ETModule"
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
package com.swmansion.rnexecutorch

import com.facebook.react.bridge.ReactApplicationContext
import android.os.Build
import android.util.Log
import androidx.annotation.RequiresApi
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactContextBaseJavaModule
import com.facebook.react.bridge.ReactMethod
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.Fetcher
import com.swmansion.rnexecutorch.utils.ProgressResponseBody
import com.swmansion.rnexecutorch.utils.ResourceType
import com.swmansion.rnexecutorch.utils.llms.ChatRole
import com.swmansion.rnexecutorch.utils.llms.ConversationManager
import com.swmansion.rnexecutorch.utils.llms.END_OF_TEXT_TOKEN
import okhttp3.OkHttpClient
import okhttp3.Request
import org.pytorch.executorch.LlamaModule
import org.pytorch.executorch.LlamaCallback
import java.io.File
import org.pytorch.executorch.LlamaModule
import java.net.URL

class RnExecutorchModule(reactContext: ReactApplicationContext) :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,40 @@ import com.facebook.react.module.model.ReactModuleInfo
import com.facebook.react.module.model.ReactModuleInfoProvider
import com.facebook.react.uimanager.ViewManager


class RnExecutorchPackage : TurboReactPackage() {
override fun createViewManagers(reactContext: ReactApplicationContext): List<ViewManager<*, *>> {
return listOf()
}

override fun getModule(name: String, reactContext: ReactApplicationContext): NativeModule? =
if (name == RnExecutorchModule.NAME) {
RnExecutorchModule(reactContext)
} else {
null
}
override fun getModule(name: String, reactContext: ReactApplicationContext): NativeModule? =
if (name == RnExecutorchModule.NAME) {
RnExecutorchModule(reactContext)
} else if (name == ETModule.NAME) {
ETModule(reactContext)
} else {
null
}

override fun getReactModuleInfoProvider(): ReactModuleInfoProvider {
return ReactModuleInfoProvider {
override fun getReactModuleInfoProvider(): ReactModuleInfoProvider {
return ReactModuleInfoProvider {
val moduleInfos: MutableMap<String, ReactModuleInfo> = HashMap()
moduleInfos[RnExecutorchModule.NAME] = ReactModuleInfo(
RnExecutorchModule.NAME,
RnExecutorchModule.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
true, // hasConstants
false, // isCxxModule
true // isTurboModule
true,
)
moduleInfos[ETModule.NAME] = ReactModuleInfo(
ETModule.NAME,
ETModule.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)
moduleInfos
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package com.swmansion.rnexecutorch.utils

import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.ReadableArray
import org.pytorch.executorch.DType
import org.pytorch.executorch.Tensor

class ArrayUtils {
companion object {
fun createByteArray(input: ReadableArray): ByteArray {
val byteArray = ByteArray(input.size())
for (i in 0 until input.size()) {
byteArray[i] = input.getInt(i).toByte()
}
return byteArray
}

fun createIntArray(input: ReadableArray): IntArray {
val intArray = IntArray(input.size())
for (i in 0 until input.size()) {
intArray[i] = input.getInt(i)
}
return intArray
}

fun createFloatArray(input: ReadableArray): FloatArray {
val floatArray = FloatArray(input.size())
for (i in 0 until input.size()) {
floatArray[i] = input.getDouble(i).toFloat()
}
return floatArray
}

fun createLongArray(input: ReadableArray): LongArray {
val longArray = LongArray(input.size())
for (i in 0 until input.size()) {
longArray[i] = input.getInt(i).toLong()
}
return longArray
}

fun createDoubleArray(input: ReadableArray): DoubleArray {
val doubleArray = DoubleArray(input.size())
for (i in 0 until input.size()) {
doubleArray[i] = input.getDouble(i)
}
return doubleArray
}

fun createReadableArray(result: Tensor): ReadableArray {
val resultArray = Arguments.createArray()
when (result.dtype()) {
DType.UINT8 -> {
val byteArray = result.dataAsByteArray
for (i in byteArray) {
resultArray.pushInt(i.toInt())
}
}

DType.INT32 -> {
val intArray = result.dataAsIntArray
for (i in intArray) {
resultArray.pushInt(i)
}
}

DType.FLOAT -> {
val longArray = result.dataAsFloatArray
for (i in longArray) {
resultArray.pushDouble(i.toDouble())
}
}

DType.DOUBLE -> {
val floatArray = result.dataAsDoubleArray
for (i in floatArray) {
resultArray.pushDouble(i)
}
}

DType.INT64 -> {
val doubleArray = result.dataAsLongArray
for (i in doubleArray) {
resultArray.pushLong(i)
}
}

else -> {
throw IllegalArgumentException("Invalid dtype: ${result.dtype()}")
}
}

return resultArray
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.swmansion.rnexecutorch
package com.swmansion.rnexecutorch.utils

import android.content.Context
import okhttp3.Call
Expand Down Expand Up @@ -113,11 +113,17 @@ class Fetcher {

private fun resolveConfigUrlFromModelUrl(modelUrl: URL): URL {
// Create a new URL using the base URL and append the desired path
val baseUrl = modelUrl.protocol + "://" + modelUrl.host + modelUrl.path.substringBefore("resolve/")
val baseUrl =
modelUrl.protocol + "://" + modelUrl.host + modelUrl.path.substringBefore("resolve/")
return URL(baseUrl + "resolve/main/config.json")
}

private fun sendRequestToUrl(url: URL, method: String, body: RequestBody?, client: OkHttpClient): Response {
private fun sendRequestToUrl(
url: URL,
method: String,
body: RequestBody?,
client: OkHttpClient
): Response {
val request = Request.Builder()
.url(url)
.method(method, body)
Expand All @@ -134,18 +140,18 @@ class Fetcher {
onComplete: (String?, Exception?) -> Unit,
listener: ProgressResponseBody.ProgressListener? = null,
) {
/*
Fetching model and tokenizer file
1. Extract file name from provided URL
2. If file name contains / it means that the file is local and we should return the path
3. Check if the file has a valid extension
a. For tokenizer, the extension should be .bin
b. For model, the extension should be .pte
4. Check if models directory exists, if not create it
5. Check if the file already exists in the models directory, if yes return the path
6. If the file does not exist, and is a tokenizer, fetch the file
7. If the file is a model, fetch the file with ProgressResponseBody
*/
/*
Fetching model and tokenizer file
1. Extract file name from provided URL
2. If file name contains / it means that the file is local and we should return the path
3. Check if the file has a valid extension
a. For tokenizer, the extension should be .bin
b. For model, the extension should be .pte
4. Check if models directory exists, if not create it
5. Check if the file already exists in the models directory, if yes return the path
6. If the file does not exist, and is a tokenizer, fetch the file
7. If the file is a model, fetch the file with ProgressResponseBody
*/
val fileName: String

try {
Expand All @@ -165,7 +171,7 @@ class Fetcher {
return
}

var tempFile = File(context.filesDir, fileName)
val tempFile = File(context.filesDir, fileName)
if (tempFile.exists()) {
tempFile.delete()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.swmansion.rnexecutorch
package com.swmansion.rnexecutorch.utils

import okhttp3.MediaType
import okhttp3.ResponseBody
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.swmansion.rnexecutorch.utils

import com.facebook.react.bridge.ReadableArray
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor

class TensorUtils {
companion object {
fun getExecutorchInput(input: ReadableArray, shape: LongArray, type: Int): EValue {
try {
when (type) {
0 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createByteArray(input), shape)
return EValue.from(inputTensor)
}

1 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createIntArray(input), shape)
return EValue.from(inputTensor)
}

2 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createLongArray(input), shape)
return EValue.from(inputTensor)
}

3 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createFloatArray(input), shape)
return EValue.from(inputTensor)
}

4 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createDoubleArray(input), shape)
return EValue.from(inputTensor)
}

else -> {
throw IllegalArgumentException("Invalid input type: $type")
}
}
} catch (e: IllegalArgumentException) {
throw e
}
}
}
}
Loading