.env (1 addition, 1 deletion)
@@ -1,4 +1,4 @@
-ONNXRUNTIME_VERSION=1.20.1
+ONNXRUNTIME_VERSION=1.22.0
 DENO_VERSION=2.1.4
 EDGE_RUNTIME_PORT=9998
 AI_INFERENCE_API_HOST=http://localhost:11434
Cargo.lock (30 additions, 22 deletions)

Generated file; diff not rendered by default.

ext/ai/Cargo.toml (3 additions, 2 deletions)
@@ -17,6 +17,7 @@ ext_ai_v8_utilities.workspace = true
 anyhow.workspace = true
 clap = { workspace = true, features = ["derive"] }
 ctor.workspace = true
+dashmap.workspace = true
 faster-hex.workspace = true
 futures.workspace = true
 futures-util = { workspace = true, features = ["io"] }
@@ -39,5 +40,5 @@ openblas-src = { version = "0.10", features = ['cblas', 'system'] }
 rand = "0.8"
 tokenizers = { version = ">=0.13.4", default-features = false, features = ["onig"] }
 
-ort = { version = "=2.0.0-rc.9", default-features = false, features = ["ndarray", "half", "load-dynamic", "cuda"] }
-ort-sys = "=2.0.0-rc.9"
+ort = { version = "=2.0.0-rc.10", default-features = false, features = ["ndarray", "half", "load-dynamic", "cuda"] }
+ort-sys = "=2.0.0-rc.10"
ext/ai/lib.rs (29 additions, 16 deletions)
@@ -21,6 +21,7 @@ use ndarray::ArrayView3;
 use ndarray::Axis;
 use ndarray::Ix3;
 use ort::inputs;
+use ort::value::TensorRef;
 use reqwest::Url;
 use session::load_session_from_url;
 use std::cell::RefCell;
@@ -39,6 +40,8 @@ use tracing::error;
 use tracing::trace_span;
 use tracing::Instrument;
 
+use crate::onnxruntime::session::as_mut_session;
+
 deno_core::extension!(
 ai,
 ops = [
@@ -180,6 +183,7 @@ async fn init_gte(state: Rc<RefCell<OpState>>) -> Result<(), Error> {
 -> Result<Vec<f32>, Error> {
 let encoded_prompt =
 tokenizer.encode(prompt, true).map_err(anyhow::Error::msg)?;
+
 let input_ids = encoded_prompt
 .get_ids()
 .iter()
@@ -198,32 +202,41 @@ async fn init_gte(state: Rc<RefCell<OpState>>) -> Result<(), Error> {
 .map(|i| *i as i64)
 .collect::<Vec<_>>();
 
-let input_ids_array = Array1::from_iter(input_ids.iter().cloned());
-let input_ids_array = input_ids_array.view().insert_axis(Axis(0));
+// Convert our flattened arrays into 2-dimensional tensors of shape [N, L] -> Since we're not batching 'N' will be always = 1
+let input_ids_array =
+TensorRef::from_array_view(([1, input_ids.len()], &*input_ids))?;
+let attention_mask_array = TensorRef::from_array_view((
+[1, encoded_prompt.len()],
+&*attention_mask,
+))?;
 
-let attention_mask_array =
-Array1::from_iter(attention_mask.iter().cloned());
-let attention_mask_array =
-attention_mask_array.view().insert_axis(Axis(0));
+let token_type_ids_array = TensorRef::from_array_view((
+[1, encoded_prompt.len()],
+&*token_type_ids,
+))?;
 
-let token_type_ids_array =
-Array1::from_iter(token_type_ids.iter().cloned());
-let token_type_ids_array =
-token_type_ids_array.view().insert_axis(Axis(0));
+let session = unsafe { as_mut_session(&session) };
 
 let outputs = trace_span!("infer_gte").in_scope(|| {
 session.run(inputs! {
-"input_ids" => input_ids_array,
-"token_type_ids" => token_type_ids_array,
-"attention_mask" => attention_mask_array,
-}?)
+"input_ids" => input_ids_array,
+"token_type_ids" => token_type_ids_array,
+"attention_mask" => attention_mask_array,
+})
 })?;
 
-let embeddings = outputs["last_hidden_state"].try_extract_tensor()?;
+let embeddings = outputs["last_hidden_state"].try_extract_array()?;
 let embeddings = embeddings.into_dimensionality::<Ix3>()?;
 
 let result = if do_mean_pooling {
-mean_pool(embeddings, attention_mask_array.insert_axis(Axis(2)))
+let attention_mask_array_clone =
+Array1::from_iter(attention_mask.iter().cloned());
+let attention_mask_array_clone = attention_mask_array_clone
+.view()
+.insert_axis(Axis(0))
+.insert_axis(Axis(2));
+
+mean_pool(embeddings, attention_mask_array_clone)
 } else {
 embeddings.into_owned().remove_axis(Axis(0))
 };
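In the mean-pooling branch the attention mask now has to be rebuilt as an owned ndarray, since the mask is handed to run as a TensorRef rather than kept around as an ndarray view; the rebuilt mask goes from [L] to [1, L, 1] so it broadcasts against the [1, L, H] hidden states. The mean_pool helper itself is defined elsewhere in lib.rs and is untouched by this PR; below is a rough sketch of what a masked mean over the token axis typically looks like under those shape assumptions, not the project's exact implementation.

```rust
use ndarray::{Array2, ArrayView3, Axis};

// Illustrative masked mean pooling, not the mean_pool defined in lib.rs.
// Assumes embeddings of shape [1, L, H] and an i64 attention mask of shape [1, L, 1].
fn mean_pool_sketch(
    embeddings: ArrayView3<f32>,
    attention_mask: ArrayView3<i64>,
) -> Array2<f32> {
    // 1.0 for real tokens, 0.0 for padding.
    let mask = attention_mask.mapv(|v| v as f32);
    // Zero out padding positions, then sum hidden states over the token axis.
    let summed = (&embeddings * &mask).sum_axis(Axis(1)); // [1, H]
    // Divide by the number of real tokens, guarding against empty sequences.
    let counts = mask.sum_axis(Axis(1)).mapv(|c| c.max(1e-9)); // [1, 1]
    summed / counts
}
```

Dividing by the per-sequence token count rather than by the padded length is what keeps padding positions from dragging the embedding toward zero.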
ext/ai/onnxruntime/mod.rs (5 additions, 1 deletion)
@@ -32,6 +32,8 @@ use tokio::sync::oneshot;
 use tracing::debug;
 use tracing::trace;
 
+use crate::onnxruntime::session::as_mut_session;
+
 #[op2(async)]
 #[to_v8]
 pub async fn op_ai_ort_init_session(
@@ -103,7 +105,9 @@ pub async fn op_ai_ort_run_session(
 JsRuntime::op_state_from(state)
 .borrow_mut()
 .spawn_cpu_accumul_blocking_scope(move || {
-let outputs = match model_session.run(input_values) {
+let session = unsafe { as_mut_session(&model_session) };
+
+let outputs = match session.run(input_values) {
 Ok(v) => v,
 Err(err) => {
 let _ = tx.send(Err(anyhow::Error::from(err)));
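Both run call sites now go through as_mut_session because rc.10's Session::run takes &mut self while these sessions are shared across ops. The helper lives in ext/ai/onnxruntime/session.rs and its body is not part of this diff; whatever it does, handing out a mutable reference to a shared session is only sound while nothing else runs that session concurrently. For contrast, a safe way to obtain the same &mut access is to put the shared session behind a lock; a minimal sketch assuming a hypothetical Arc<Mutex<Session>> handle, which is not the project's actual type:

```rust
use ort::session::Session;
use std::sync::{Arc, Mutex};

// Sketch of the locking alternative: the Mutex yields the &mut Session that
// rc.10's Session::run requires, at the cost of serializing concurrent ops,
// which is presumably what the unsafe as_mut_session helper avoids.
fn with_session<R>(
    shared: &Arc<Mutex<Session>>,
    f: impl FnOnce(&mut Session) -> R,
) -> R {
    let mut guard = shared.lock().expect("session lock poisoned");
    f(&mut *guard)
}
```

A caller would run the session and extract whatever it needs inside the closure, since the outputs borrow from the session for as long as they are alive.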