From dcae3d088521d1ef430b50c9558a1d617375e186 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Tue, 19 Aug 2025 20:43:07 +0000 Subject: [PATCH 01/11] chore(ai): update `Cargo.toml` --- ext/ai/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/ai/Cargo.toml b/ext/ai/Cargo.toml index fb4c2b48d..e1cee5f64 100644 --- a/ext/ai/Cargo.toml +++ b/ext/ai/Cargo.toml @@ -39,5 +39,5 @@ openblas-src = { version = "0.10", features = ['cblas', 'system'] } rand = "0.8" tokenizers = { version = ">=0.13.4", default-features = false, features = ["onig"] } -ort = { version = "=2.0.0-rc.9", default-features = false, features = ["ndarray", "half", "load-dynamic", "cuda"] } -ort-sys = "=2.0.0-rc.9" +ort = { version = "=2.0.0-rc.10", default-features = false, features = ["ndarray", "half", "load-dynamic", "cuda"] } +ort-sys = "=2.0.0-rc.10" From a885622305c171338a28cc70916a037be0cab197 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Tue, 19 Aug 2025 20:43:23 +0000 Subject: [PATCH 02/11] chore(ai): update `Cargo.lock` --- Cargo.lock | 51 +++++++++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 659f6af5d..f66af55a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2058,7 +2058,7 @@ dependencies = [ "serde", "serde_json", "serde_v8", - "smallvec", + "smallvec 1.13.1", "sourcemap 8.0.1", "static_assertions", "tokio", @@ -2293,7 +2293,7 @@ dependencies = [ "ring 0.17.14", "scopeguard", "serde", - "smallvec", + "smallvec 1.13.1", "thiserror 2.0.8", "tokio", "tokio-util", @@ -4154,7 +4154,7 @@ dependencies = [ "rand", "resolv-conf", "serde", - "smallvec", + "smallvec 1.13.1", "thiserror 2.0.8", "tokio", "tracing", @@ -4338,7 +4338,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "smallvec", + "smallvec 1.13.1", "tokio", "want", ] @@ -4515,7 +4515,7 @@ dependencies = [ "icu_normalizer_data", "icu_properties", "icu_provider", - "smallvec", + "smallvec 1.13.1", "utf16_iter", "utf8_iter", "write16", @@ -4590,7 +4590,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ "idna_adapter", - "smallvec", + "smallvec 1.13.1", "utf8_iter", ] @@ -5200,7 +5200,7 @@ dependencies = [ "parking_lot", "quanta", "rustc_version 0.4.0", - "smallvec", + "smallvec 1.13.1", "tagptr", "thiserror 1.0.62", "triomphe", @@ -5434,7 +5434,7 @@ dependencies = [ "num-traits", "rand", "serde", - "smallvec", + "smallvec 1.13.1", "zeroize", ] @@ -5730,22 +5730,23 @@ checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" [[package]] name = "ort" -version = "2.0.0-rc.9" +version = "2.0.0-rc.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52afb44b6b0cffa9bf45e4d37e5a4935b0334a51570658e279e9e3e6cf324aa5" +checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721" dependencies = [ "half", "libloading 0.8.1", "ndarray", "ort-sys", + "smallvec 2.0.0-alpha.10", "tracing", ] [[package]] name = "ort-sys" -version = "2.0.0-rc.9" +version = "2.0.0-rc.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c41d7757331aef2d04b9cb09b45583a59217628beaf91895b7e76187b6e8c088" +checksum = "e2aba9f5c7c479925205799216e7e5d07cc1d4fa76ea8058c60a9a30f6a4e890" dependencies = [ "pkg-config", ] @@ -5858,7 +5859,7 @@ dependencies = [ "cfg-if", "libc", "redox_syscall", - "smallvec", + "smallvec 1.13.1", "windows-targets 0.48.5", ] @@ -6776,7 +6777,7 @@ dependencies = [ "fallible-streaming-iterator", "hashlink 0.9.1", "libsqlite3-sys", - "smallvec", + "smallvec 1.13.1", ] [[package]] @@ -7223,7 +7224,7 @@ source = "git+https://github.com/supabase/deno_core?branch=324-supabase#4f0b7554 dependencies = [ "num-bigint", "serde", - "smallvec", + "smallvec 1.13.1", "thiserror 1.0.62", "v8", ] @@ -7436,6 +7437,12 @@ version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" +[[package]] +name = "smallvec" +version = "2.0.0-alpha.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d44cfb396c3caf6fbfd0ab422af02631b69ddd96d2eff0b0f0724f9024051b" + [[package]] name = "smartstring" version = "1.0.1" @@ -7831,7 +7838,7 @@ dependencies = [ "num-traits", "phf", "serde", - "smallvec", + "smallvec 1.13.1", "smartstring", "stacker", "swc_atoms", @@ -7854,7 +7861,7 @@ dependencies = [ "phf", "rustc-hash 1.1.0", "serde", - "smallvec", + "smallvec 1.13.1", "swc_atoms", "swc_common", "swc_ecma_ast", @@ -7923,7 +7930,7 @@ dependencies = [ "either", "rustc-hash 1.1.0", "serde", - "smallvec", + "smallvec 1.13.1", "swc_atoms", "swc_common", "swc_ecma_ast", @@ -8644,7 +8651,7 @@ dependencies = [ "serde", "serde_json", "sharded-slab", - "smallvec", + "smallvec 1.13.1", "thread_local", "tracing", "tracing-core", @@ -8781,7 +8788,7 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" dependencies = [ - "smallvec", + "smallvec 1.13.1", ] [[package]] @@ -9150,7 +9157,7 @@ dependencies = [ "ron", "rustc-hash 1.1.0", "serde", - "smallvec", + "smallvec 1.13.1", "thiserror 1.0.62", "web-sys", "wgpu-hal", @@ -9191,7 +9198,7 @@ dependencies = [ "range-alloc", "raw-window-handle", "rustc-hash 1.1.0", - "smallvec", + "smallvec 1.13.1", "thiserror 1.0.62", "wasm-bindgen", "web-sys", From 9e8da1c30b8537b19e1b0d062fd3e3b61d9a6a21 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Wed, 20 Aug 2025 11:04:14 +0000 Subject: [PATCH 03/11] chore(ai): update `ONNXRUNTIME_VERSION` --- .env | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.env b/.env index f4c018213..201571e3f 100644 --- a/.env +++ b/.env @@ -1,4 +1,4 @@ -ONNXRUNTIME_VERSION=1.20.1 +ONNXRUNTIME_VERSION=1.22.0 DENO_VERSION=2.1.4 EDGE_RUNTIME_PORT=9998 AI_INFERENCE_API_HOST=http://localhost:11434 From 1059f1b24b33dbcd89afb44aa63adb76b89f4047 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Wed, 20 Aug 2025 11:42:58 +0000 Subject: [PATCH 04/11] stamp: solving ort api breaking changes --- ext/ai/lib.rs | 40 +++++++++++-------- ext/ai/onnxruntime/mod.rs | 10 ++++- ext/ai/onnxruntime/model.rs | 45 +++++++++++++++------ ext/ai/onnxruntime/session.rs | 75 +++++++++++++++++------------------ ext/ai/onnxruntime/tensor.rs | 34 +++++++++++++--- 5 files changed, 128 insertions(+), 76 deletions(-) diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index 224b0450f..e6a7f0725 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -21,6 +21,7 @@ use ndarray::ArrayView3; use ndarray::Axis; use ndarray::Ix3; use ort::inputs; +use ort::value::TensorRef; use reqwest::Url; use session::load_session_from_url; use std::cell::RefCell; @@ -180,6 +181,7 @@ async fn init_gte(state: Rc>) -> Result<(), Error> { -> Result, Error> { let encoded_prompt = tokenizer.encode(prompt, true).map_err(anyhow::Error::msg)?; + let input_ids = encoded_prompt .get_ids() .iter() @@ -198,32 +200,36 @@ async fn init_gte(state: Rc>) -> Result<(), Error> { .map(|i| *i as i64) .collect::>(); - let input_ids_array = Array1::from_iter(input_ids.iter().cloned()); - let input_ids_array = input_ids_array.view().insert_axis(Axis(0)); + let input_ids_array = TensorRef::from_array_view(([input_ids.len(), 1], &*input_ids))?; + let attention_mask_array = TensorRef::from_array_view(([1, encoded_prompt.len()], &*attention_mask))?; + let token_type_ids_array = TensorRef::from_array_view(([1, encoded_prompt.len()], &*token_type_ids))?; - let attention_mask_array = - Array1::from_iter(attention_mask.iter().cloned()); - let attention_mask_array = - attention_mask_array.view().insert_axis(Axis(0)); - let token_type_ids_array = - Array1::from_iter(token_type_ids.iter().cloned()); - let token_type_ids_array = - token_type_ids_array.view().insert_axis(Axis(0)); + let Ok(mut guard) = session.lock() else { + let err = anyhow!("failed to lock session"); + error!(reason = ?err); + return Err(err); + }; let outputs = trace_span!("infer_gte").in_scope(|| { - session.run(inputs! { - "input_ids" => input_ids_array, - "token_type_ids" => token_type_ids_array, - "attention_mask" => attention_mask_array, - }?) + guard.run(inputs! { + "input_ids" => input_ids_array, + "token_type_ids" => token_type_ids_array, + "attention_mask" => attention_mask_array, + }) })?; - let embeddings = outputs["last_hidden_state"].try_extract_tensor()?; + let embeddings = outputs["last_hidden_state"].try_extract_array()?; let embeddings = embeddings.into_dimensionality::()?; let result = if do_mean_pooling { - mean_pool(embeddings, attention_mask_array.insert_axis(Axis(2))) + let attention_mask_array_clone= Array1::from_iter(attention_mask.iter().cloned()); + let attention_mask_array_clone= attention_mask_array_clone.view() + .insert_axis(Axis(0)) + .insert_axis(Axis(2)); + + println!("attention_mask: {attention_mask_array_clone:?}"); + mean_pool(embeddings, attention_mask_array_clone) } else { embeddings.into_owned().remove_axis(Axis(0)) }; diff --git a/ext/ai/onnxruntime/mod.rs b/ext/ai/onnxruntime/mod.rs index b1dd136f8..d77abda0d 100644 --- a/ext/ai/onnxruntime/mod.rs +++ b/ext/ai/onnxruntime/mod.rs @@ -10,6 +10,7 @@ use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; use std::sync::Arc; +use std::sync::Mutex; use anyhow::anyhow; use anyhow::Context; @@ -56,7 +57,7 @@ pub async fn op_ai_ort_init_session( let mut state = state.borrow_mut(); let mut sessions = - { state.try_take::>>().unwrap_or_default() }; + { state.try_take::>>>().unwrap_or_default() }; sessions.push(model.get_session()); state.put(sessions); @@ -103,7 +104,12 @@ pub async fn op_ai_ort_run_session( JsRuntime::op_state_from(state) .borrow_mut() .spawn_cpu_accumul_blocking_scope(move || { - let outputs = match model_session.run(input_values) { + let Ok(mut session_guard) = model_session.lock() else { + let _ = tx.send(Err(anyhow!("failed to lock model session"))); + return; + }; + + let outputs = match session_guard.run(input_values) { Ok(v) => v, Err(err) => { let _ = tx.send(Err(anyhow::Error::from(err))); diff --git a/ext/ai/onnxruntime/model.rs b/ext/ai/onnxruntime/model.rs index f3a17e6e4..ba6357234 100644 --- a/ext/ai/onnxruntime/model.rs +++ b/ext/ai/onnxruntime/model.rs @@ -1,5 +1,7 @@ use std::sync::Arc; +use std::sync::Mutex; +use anyhow::anyhow; use anyhow::Result; use deno_core::serde_v8::to_v8; use deno_core::ToV8; @@ -27,56 +29,73 @@ impl std::fmt::Display for ModelInfo { #[derive(Debug)] pub struct Model { info: ModelInfo, - session: Arc, + session: Arc>, } impl Model { - fn new(session_with_id: SessionWithId) -> Self { - let input_names = session_with_id - .session + fn new(session_with_id: SessionWithId) -> Result { + let (input_names, output_names) = { + let Ok(session_guard) = session_with_id.session.lock() else { + return Err(anyhow!("Could not lock model session {}", session_with_id.id)); + }; + + let input_names = session_guard .inputs .iter() .map(|input| input.name.clone()) .collect::>(); - let output_names = session_with_id - .session + let output_names = session_guard .outputs .iter() .map(|output| output.name.clone()) .collect::>(); - Self { + (input_names, output_names) + }; + + Ok(Self { info: ModelInfo { id: session_with_id.id, input_names, output_names, }, session: session_with_id.session, - } + }) } pub fn get_info(&self) -> ModelInfo { self.info.clone() } - pub fn get_session(&self) -> Arc { + pub fn get_session(&self) -> Arc> { self.session.clone() } pub async fn from_id(id: &str) -> Option { - get_session(id) + let session = { + get_session(id) .await .map(|it| SessionWithId::from((id.to_string(), it))) - .map(Self::new) + }; + + let Some(session) = session else { + return None; + }; + + Self::new(session).ok() } pub async fn from_url(model_url: Url) -> Result { - load_session_from_url(model_url).await.map(Self::new) + let session = load_session_from_url(model_url).await?; + + Self::new(session) } pub async fn from_bytes(model_bytes: &[u8]) -> Result { - load_session_from_bytes(model_bytes).await.map(Self::new) + let session = load_session_from_bytes(model_bytes).await?; + + Self::new(session) } } diff --git a/ext/ai/onnxruntime/session.rs b/ext/ai/onnxruntime/session.rs index 6205e8550..6bc902af8 100644 --- a/ext/ai/onnxruntime/session.rs +++ b/ext/ai/onnxruntime/session.rs @@ -5,7 +5,8 @@ use reqwest::Url; use std::collections::HashMap; use std::hash::Hasher; use std::sync::Arc; -use tokio::sync::Mutex; +use std::sync::Mutex; +use tokio::sync::Mutex as AsyncMutex; use tokio_util::compat::FuturesAsyncWriteCompatExt; use tracing::debug; use tracing::instrument; @@ -25,17 +26,17 @@ use ort::session::Session; use crate::onnx::ensure_onnx_env_init; -static SESSIONS: Lazy>>> = - Lazy::new(|| Mutex::new(HashMap::new())); +static SESSIONS: Lazy>>>> = + Lazy::new(|| AsyncMutex::new(HashMap::new())); #[derive(Debug)] pub struct SessionWithId { pub(crate) id: String, - pub(crate) session: Arc, + pub(crate) session: Arc>, } -impl From<(String, Arc)> for SessionWithId { - fn from(value: (String, Arc)) -> Self { +impl From<(String, Arc>)> for SessionWithId { + fn from(value: (String, Arc>)) -> Self { Self { id: value.0, session: value.1, @@ -50,7 +51,7 @@ impl std::fmt::Display for SessionWithId { } impl SessionWithId { - pub fn into_split(self) -> (String, Arc) { + pub fn into_split(self) -> (String, Arc>) { (self.id, self.session) } } @@ -74,53 +75,51 @@ pub(crate) fn get_session_builder() -> Result { Ok(builder) } -fn cpu_execution_provider( -) -> Box> { - Box::new( - [ - // NOTE(Nyannacha): See the comment above. This makes `enable_cpu_mem_arena` set to - // False. - // - // Backgrounds: - // [1]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#9-18 - // [2]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#46-50 - CPUExecutionProvider::default().build(), - ] - .into_iter(), - ) +fn cpu_execution_provider() -> ExecutionProviderDispatch { + // NOTE(Nyannacha): See the comment above. This makes `enable_cpu_mem_arena` set to + // False. + // + // Backgrounds: + // [1]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#9-18 + // [2]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#46-50 + CPUExecutionProvider::default().build() } -fn cuda_execution_provider( -) -> Box> { +fn cuda_execution_provider() -> Option { let cuda = CUDAExecutionProvider::default(); - let providers = match cuda.is_available() { - Ok(is_cuda_available) => { - debug!(cuda_support = is_cuda_available); - if is_cuda_available { - vec![cuda.build()] - } else { - vec![] - } - } + let is_cuda_available = cuda.is_available().is_ok_and(|v| v); + debug!(cuda_support = is_cuda_available); + + if is_cuda_available { + Some(cuda.build()) + }else{ + None + } +} + +fn get_execution_providers( +) -> Vec { + let cpu = cpu_execution_provider(); - _ => vec![], + if let Some(cuda) = cuda_execution_provider() { + return [cuda, cpu].to_vec(); }; - Box::new(providers.into_iter().chain(cpu_execution_provider())) + [cpu].to_vec() } -fn create_session(model_bytes: &[u8]) -> Result, Error> { +fn create_session(model_bytes: &[u8]) -> Result>, Error> { let session = { if let Some(err) = ensure_onnx_env_init() { return Err(anyhow!("failed to create onnx environment: {err}")); } get_session_builder()? - .with_execution_providers(cuda_execution_provider())? + .with_execution_providers(get_execution_providers())? .commit_from_memory(model_bytes)? }; - Ok(Arc::new(session)) + Ok(Arc::new(Mutex::new(session))) } #[instrument(level = "debug", skip_all, fields(model_bytes = model_bytes.len()), err)] @@ -181,7 +180,7 @@ pub(crate) async fn load_session_from_url( Ok((session_id, session).into()) } -pub(crate) async fn get_session(id: &str) -> Option> { +pub(crate) async fn get_session(id: &str) -> Option>> { SESSIONS.lock().await.get(id).cloned() } diff --git a/ext/ai/onnxruntime/tensor.rs b/ext/ai/onnxruntime/tensor.rs index 86b2e7ae8..d72429584 100644 --- a/ext/ai/onnxruntime/tensor.rs +++ b/ext/ai/onnxruntime/tensor.rs @@ -14,6 +14,7 @@ use ort::memory::MemoryInfo; use ort::memory::MemoryType; use ort::session::SessionInputValue; use ort::tensor::PrimitiveTensorElementType; +use ort::tensor::Shape; use ort::tensor::TensorElementType; use ort::value::DynValue; use ort::value::DynValueTypeMarker; @@ -32,8 +33,7 @@ macro_rules! v8_slice_from { (tensor::<$type:ident>($tensor:expr)) => {{ // We must ensure there's some detection to avoid `null pointer` errors // https://github.com/pykeio/ort/issues/185 - let n_detections = $tensor.shape()?[0]; - if n_detections == 0 { + if $tensor.shape().is_empty() { let buf_store = v8::ArrayBuffer::new_backing_store_from_vec(vec![]).make_shared(); let buffer_slice = unsafe { @@ -43,7 +43,7 @@ macro_rules! v8_slice_from { buffer_slice } else { let (_, raw_tensor) = $tensor - .try_extract_raw_tensor_mut::<$type>() + .try_extract_tensor_mut::<$type>() .map_err(AnyError::from)?; let tensor_ptr = raw_tensor.as_ptr(); @@ -113,6 +113,24 @@ pub enum JsTensorType { Uint64, /// Brain 16-bit floating point number, equivalent to [`half::bf16`] (requires the `half` feature). Bfloat16, + Complex64, + Complex128, + /// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values and no infinite + /// values. + Float8E4M3FN, + /// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values, no infinite + /// values, and no negative zero. + Float8E4M3FNUZ, + /// 8-bit floating point number with 5 exponent bits and 2 mantissa bits. + Float8E5M2, + /// 8-bit floating point number with 5 exponent bits and 2 mantissa bits, with only NaN values, no infinite + /// values, and no negative zero. + Float8E5M2FNUZ, + /// 4-bit unsigned integer. + Uint4, + /// 4-bit signed integer. + Int4, + Undefined } #[derive(Serialize, Deserialize)] @@ -183,7 +201,7 @@ impl JsTensor { TensorRefMut::::from_raw( memory_info, data.as_mut_ptr() as *mut c_void, - self.dims, + Shape::new(self.dims), ) }?; @@ -205,7 +223,7 @@ impl JsTensor { )); }; - Tensor::from_string_array((self.dims, data))?.into() + Tensor::from_string_array((self.dims, data.as_slice()))?.into() } TensorElementType::Int8 => self.extract_ort_tensor_ref::()?.into(), TensorElementType::Uint8 => self.extract_ort_tensor_ref::()?.into(), @@ -222,6 +240,9 @@ impl JsTensor { TensorElementType::Bfloat16 => { return Err(anyhow!("'half::bf16' is not supported by JS tensor.")) } + other => { + return Err(anyhow!("'{other:?}' is not supported by JS tensor.")) + } }; Ok(input_value) @@ -243,7 +264,7 @@ impl ToJsTensor { "JS only support 'ort::Value' of 'Tensor' type, got '{value:?}'." )); }; - let tensor_shape = value.shape()?; + let tensor_shape = value.shape().to_vec(); let buffer_slice = match tensor_type { TensorElementType::Float32 => v8_slice_from!(tensor::(value)), @@ -260,6 +281,7 @@ impl ToJsTensor { TensorElementType::String => todo!(), TensorElementType::Float16 => todo!(), TensorElementType::Bfloat16 => todo!(), + _ => todo!() }; Ok(Self { From 98eb8768ba2370e86c7a15876bf82c394d65b4c8 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Wed, 20 Aug 2025 12:45:02 +0000 Subject: [PATCH 05/11] revert: ensure N detections to avoid `null pointer` errors --- ext/ai/onnxruntime/tensor.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/ai/onnxruntime/tensor.rs b/ext/ai/onnxruntime/tensor.rs index d72429584..4b51eac00 100644 --- a/ext/ai/onnxruntime/tensor.rs +++ b/ext/ai/onnxruntime/tensor.rs @@ -33,7 +33,8 @@ macro_rules! v8_slice_from { (tensor::<$type:ident>($tensor:expr)) => {{ // We must ensure there's some detection to avoid `null pointer` errors // https://github.com/pykeio/ort/issues/185 - if $tensor.shape().is_empty() { + let n_detections = $tensor.shape()[0]; + if n_detections == 0 { let buf_store = v8::ArrayBuffer::new_backing_store_from_vec(vec![]).make_shared(); let buffer_slice = unsafe { From 15a4aded357d3665179af49d3a617aceabc5822a Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Thu, 21 Aug 2025 09:25:54 +0000 Subject: [PATCH 06/11] stamp: clippy & fmt --- ext/ai/lib.rs | 25 +++++++++++++++------- ext/ai/onnxruntime/mod.rs | 7 +++++-- ext/ai/onnxruntime/model.rs | 39 ++++++++++++++++++----------------- ext/ai/onnxruntime/session.rs | 19 ++++++++--------- ext/ai/onnxruntime/tensor.rs | 36 ++++++++++++++++---------------- 5 files changed, 69 insertions(+), 57 deletions(-) diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index e6a7f0725..e108c24b3 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -200,10 +200,17 @@ async fn init_gte(state: Rc>) -> Result<(), Error> { .map(|i| *i as i64) .collect::>(); - let input_ids_array = TensorRef::from_array_view(([input_ids.len(), 1], &*input_ids))?; - let attention_mask_array = TensorRef::from_array_view(([1, encoded_prompt.len()], &*attention_mask))?; - let token_type_ids_array = TensorRef::from_array_view(([1, encoded_prompt.len()], &*token_type_ids))?; - + let input_ids_array = + TensorRef::from_array_view(([input_ids.len(), 1], &*input_ids))?; + let attention_mask_array = TensorRef::from_array_view(( + [1, encoded_prompt.len()], + &*attention_mask, + ))?; + + let token_type_ids_array = TensorRef::from_array_view(( + [1, encoded_prompt.len()], + &*token_type_ids, + ))?; let Ok(mut guard) = session.lock() else { let err = anyhow!("failed to lock session"); @@ -223,10 +230,12 @@ async fn init_gte(state: Rc>) -> Result<(), Error> { let embeddings = embeddings.into_dimensionality::()?; let result = if do_mean_pooling { - let attention_mask_array_clone= Array1::from_iter(attention_mask.iter().cloned()); - let attention_mask_array_clone= attention_mask_array_clone.view() - .insert_axis(Axis(0)) - .insert_axis(Axis(2)); + let attention_mask_array_clone = + Array1::from_iter(attention_mask.iter().cloned()); + let attention_mask_array_clone = attention_mask_array_clone + .view() + .insert_axis(Axis(0)) + .insert_axis(Axis(2)); println!("attention_mask: {attention_mask_array_clone:?}"); mean_pool(embeddings, attention_mask_array_clone) diff --git a/ext/ai/onnxruntime/mod.rs b/ext/ai/onnxruntime/mod.rs index d77abda0d..0eda4792e 100644 --- a/ext/ai/onnxruntime/mod.rs +++ b/ext/ai/onnxruntime/mod.rs @@ -56,8 +56,11 @@ pub async fn op_ai_ort_init_session( }; let mut state = state.borrow_mut(); - let mut sessions = - { state.try_take::>>>().unwrap_or_default() }; + let mut sessions = { + state + .try_take::>>>() + .unwrap_or_default() + }; sessions.push(model.get_session()); state.put(sessions); diff --git a/ext/ai/onnxruntime/model.rs b/ext/ai/onnxruntime/model.rs index ba6357234..0f8925845 100644 --- a/ext/ai/onnxruntime/model.rs +++ b/ext/ai/onnxruntime/model.rs @@ -36,20 +36,23 @@ impl Model { fn new(session_with_id: SessionWithId) -> Result { let (input_names, output_names) = { let Ok(session_guard) = session_with_id.session.lock() else { - return Err(anyhow!("Could not lock model session {}", session_with_id.id)); - }; - - let input_names = session_guard - .inputs - .iter() - .map(|input| input.name.clone()) - .collect::>(); - - let output_names = session_guard - .outputs - .iter() - .map(|output| output.name.clone()) - .collect::>(); + return Err(anyhow!( + "Could not lock model session {}", + session_with_id.id + )); + }; + + let input_names = session_guard + .inputs + .iter() + .map(|input| input.name.clone()) + .collect::>(); + + let output_names = session_guard + .outputs + .iter() + .map(|output| output.name.clone()) + .collect::>(); (input_names, output_names) }; @@ -75,13 +78,11 @@ impl Model { pub async fn from_id(id: &str) -> Option { let session = { get_session(id) - .await - .map(|it| SessionWithId::from((id.to_string(), it))) + .await + .map(|it| SessionWithId::from((id.to_string(), it))) }; - let Some(session) = session else { - return None; - }; + let session = session?; Self::new(session).ok() } diff --git a/ext/ai/onnxruntime/session.rs b/ext/ai/onnxruntime/session.rs index 6bc902af8..8bfb74b32 100644 --- a/ext/ai/onnxruntime/session.rs +++ b/ext/ai/onnxruntime/session.rs @@ -76,13 +76,13 @@ pub(crate) fn get_session_builder() -> Result { } fn cpu_execution_provider() -> ExecutionProviderDispatch { - // NOTE(Nyannacha): See the comment above. This makes `enable_cpu_mem_arena` set to - // False. - // - // Backgrounds: - // [1]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#9-18 - // [2]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#46-50 - CPUExecutionProvider::default().build() + // NOTE(Nyannacha): See the comment above. This makes `enable_cpu_mem_arena` set to + // False. + // + // Backgrounds: + // [1]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#9-18 + // [2]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#46-50 + CPUExecutionProvider::default().build() } fn cuda_execution_provider() -> Option { @@ -92,13 +92,12 @@ fn cuda_execution_provider() -> Option { if is_cuda_available { Some(cuda.build()) - }else{ + } else { None } } -fn get_execution_providers( -) -> Vec { +fn get_execution_providers() -> Vec { let cpu = cpu_execution_provider(); if let Some(cuda) = cuda_execution_provider() { diff --git a/ext/ai/onnxruntime/tensor.rs b/ext/ai/onnxruntime/tensor.rs index 4b51eac00..cf805b8ea 100644 --- a/ext/ai/onnxruntime/tensor.rs +++ b/ext/ai/onnxruntime/tensor.rs @@ -115,23 +115,23 @@ pub enum JsTensorType { /// Brain 16-bit floating point number, equivalent to [`half::bf16`] (requires the `half` feature). Bfloat16, Complex64, - Complex128, - /// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values and no infinite - /// values. - Float8E4M3FN, - /// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values, no infinite - /// values, and no negative zero. - Float8E4M3FNUZ, - /// 8-bit floating point number with 5 exponent bits and 2 mantissa bits. - Float8E5M2, - /// 8-bit floating point number with 5 exponent bits and 2 mantissa bits, with only NaN values, no infinite - /// values, and no negative zero. - Float8E5M2FNUZ, - /// 4-bit unsigned integer. - Uint4, - /// 4-bit signed integer. - Int4, - Undefined + Complex128, + /// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values and no infinite + /// values. + Float8E4M3FN, + /// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values, no infinite + /// values, and no negative zero. + Float8E4M3FNUZ, + /// 8-bit floating point number with 5 exponent bits and 2 mantissa bits. + Float8E5M2, + /// 8-bit floating point number with 5 exponent bits and 2 mantissa bits, with only NaN values, no infinite + /// values, and no negative zero. + Float8E5M2FNUZ, + /// 4-bit unsigned integer. + Uint4, + /// 4-bit signed integer. + Int4, + Undefined, } #[derive(Serialize, Deserialize)] @@ -282,7 +282,7 @@ impl ToJsTensor { TensorElementType::String => todo!(), TensorElementType::Float16 => todo!(), TensorElementType::Bfloat16 => todo!(), - _ => todo!() + _ => todo!(), }; Ok(Self { From 5bb52dfaa28b8890116fd93b9559f16b30b51ef4 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Tue, 9 Sep 2025 10:02:02 +0000 Subject: [PATCH 07/11] fix: input ids array with wrong shape --- ext/ai/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index e108c24b3..c694c1d92 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -200,8 +200,9 @@ async fn init_gte(state: Rc>) -> Result<(), Error> { .map(|i| *i as i64) .collect::>(); + // Convert our flattened arrays into 2-dimensional tensors of shape [N, L] -> Since we're not batching 'N' will be always = 1 let input_ids_array = - TensorRef::from_array_view(([input_ids.len(), 1], &*input_ids))?; + TensorRef::from_array_view(([1, input_ids.len()], &*input_ids))?; let attention_mask_array = TensorRef::from_array_view(( [1, encoded_prompt.len()], &*attention_mask, @@ -237,7 +238,6 @@ async fn init_gte(state: Rc>) -> Result<(), Error> { .insert_axis(Axis(0)) .insert_axis(Axis(2)); - println!("attention_mask: {attention_mask_array_clone:?}"); mean_pool(embeddings, attention_mask_array_clone) } else { embeddings.into_owned().remove_axis(Axis(0)) From ef255cefb94ae198de9338bd74a69a85de3b5a44 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Wed, 10 Sep 2025 09:01:02 +0000 Subject: [PATCH 08/11] stamp: removing `todo!()` from Tensor conversions --- ext/ai/onnxruntime/tensor.rs | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/ext/ai/onnxruntime/tensor.rs b/ext/ai/onnxruntime/tensor.rs index cf805b8ea..e18f9d1ea 100644 --- a/ext/ai/onnxruntime/tensor.rs +++ b/ext/ai/onnxruntime/tensor.rs @@ -235,12 +235,6 @@ impl JsTensor { TensorElementType::Int64 => self.extract_ort_tensor_ref::()?.into(), TensorElementType::Uint64 => self.extract_ort_tensor_ref::()?.into(), TensorElementType::Bool => self.extract_ort_tensor_ref::()?.into(), - TensorElementType::Float16 => { - return Err(anyhow!("'half::f16' is not supported by JS tensor.")) - } - TensorElementType::Bfloat16 => { - return Err(anyhow!("'half::bf16' is not supported by JS tensor.")) - } other => { return Err(anyhow!("'{other:?}' is not supported by JS tensor.")) } @@ -279,10 +273,9 @@ impl ToJsTensor { TensorElementType::Int64 => v8_slice_from!(tensor::(value)), TensorElementType::Uint64 => v8_slice_from!(tensor::(value)), TensorElementType::Bool => v8_slice_from!(tensor::(value)), - TensorElementType::String => todo!(), - TensorElementType::Float16 => todo!(), - TensorElementType::Bfloat16 => todo!(), - _ => todo!(), + other => { + return Err(anyhow!("'{other:?}' is not supported by JS tensor.")) + }, }; Ok(Self { From 83918cea148bb91d886c48e01ab0ce303343d279 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Wed, 10 Sep 2025 09:04:41 +0000 Subject: [PATCH 09/11] stamp: format --- ext/ai/onnxruntime/tensor.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ai/onnxruntime/tensor.rs b/ext/ai/onnxruntime/tensor.rs index e18f9d1ea..a3758a574 100644 --- a/ext/ai/onnxruntime/tensor.rs +++ b/ext/ai/onnxruntime/tensor.rs @@ -275,7 +275,7 @@ impl ToJsTensor { TensorElementType::Bool => v8_slice_from!(tensor::(value)), other => { return Err(anyhow!("'{other:?}' is not supported by JS tensor.")) - }, + } }; Ok(Self { From 583b6dba8b826c0d86c0bcbea0ad30ec15c912f8 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Mon, 15 Sep 2025 19:35:13 +0000 Subject: [PATCH 10/11] feat: using dashmap to store sessions --- Cargo.lock | 1 + ext/ai/Cargo.toml | 1 + ext/ai/onnxruntime/session.rs | 28 ++++++++++++---------------- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f66af55a2..7fb0b1d6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3153,6 +3153,7 @@ dependencies = [ "clap", "convert_case", "ctor", + "dashmap", "deno_core", "ext_ai_v8_utilities", "faster-hex", diff --git a/ext/ai/Cargo.toml b/ext/ai/Cargo.toml index e1cee5f64..444057b42 100644 --- a/ext/ai/Cargo.toml +++ b/ext/ai/Cargo.toml @@ -17,6 +17,7 @@ ext_ai_v8_utilities.workspace = true anyhow.workspace = true clap = { workspace = true, features = ["derive"] } ctor.workspace = true +dashmap.workspace = true faster-hex.workspace = true futures.workspace = true futures-util = { workspace = true, features = ["io"] } diff --git a/ext/ai/onnxruntime/session.rs b/ext/ai/onnxruntime/session.rs index 8bfb74b32..0d8478eaa 100644 --- a/ext/ai/onnxruntime/session.rs +++ b/ext/ai/onnxruntime/session.rs @@ -1,12 +1,11 @@ +use dashmap::DashMap; use deno_core::error::AnyError; use futures::io::AllowStdIo; use once_cell::sync::Lazy; use reqwest::Url; -use std::collections::HashMap; use std::hash::Hasher; use std::sync::Arc; use std::sync::Mutex; -use tokio::sync::Mutex as AsyncMutex; use tokio_util::compat::FuturesAsyncWriteCompatExt; use tracing::debug; use tracing::instrument; @@ -26,8 +25,8 @@ use ort::session::Session; use crate::onnx::ensure_onnx_env_init; -static SESSIONS: Lazy>>>> = - Lazy::new(|| AsyncMutex::new(HashMap::new())); +static SESSIONS: Lazy>>> = + Lazy::new(DashMap::new); #[derive(Debug)] pub struct SessionWithId { @@ -136,16 +135,14 @@ pub(crate) async fn load_session_from_bytes( faster_hex::hex_string(&hasher.finish().to_be_bytes()) }; - let mut sessions = SESSIONS.lock().await; - - if let Some(session) = sessions.get(&session_id) { + if let Some(session) = SESSIONS.get(&session_id) { return Ok((session_id, session.clone()).into()); } trace!(session_id, "new session"); let session = create_session(model_bytes)?; - sessions.insert(session_id.clone(), session.clone()); + SESSIONS.insert(session_id.clone(), session.clone()); Ok((session_id, session).into()) } @@ -156,9 +153,7 @@ pub(crate) async fn load_session_from_url( ) -> Result { let session_id = fxhash::hash(model_url.as_str()).to_string(); - let mut sessions = SESSIONS.lock().await; - - if let Some(session) = sessions.get(&session_id) { + if let Some(session) = SESSIONS.get(&session_id) { debug!(session_id, "use existing session"); return Ok((session_id, session.clone()).into()); } @@ -174,22 +169,23 @@ pub(crate) async fn load_session_from_url( let session = create_session(model_bytes.as_slice())?; debug!(session_id, "new session"); - sessions.insert(session_id.clone(), session.clone()); + SESSIONS.insert(session_id.clone(), session.clone()); Ok((session_id, session).into()) } pub(crate) async fn get_session(id: &str) -> Option>> { - SESSIONS.lock().await.get(id).cloned() + SESSIONS.get(id).map(|value| value.pair().1.clone()) } pub async fn cleanup() -> Result { let mut remove_counter = 0; { - let mut guard = SESSIONS.lock().await; + //let mut guard = SESSIONS.lock().await; let mut to_be_removed = vec![]; - for (key, session) in &mut *guard { + for v in SESSIONS.iter() { + let (key, session) = v.pair(); // Since we're currently referencing the session at this point // It also will increments the counter, so we need to check: counter > 1 if Arc::strong_count(session) > 1 { @@ -200,7 +196,7 @@ pub async fn cleanup() -> Result { } for key in to_be_removed { - let old_store = guard.remove(&key); + let old_store = SESSIONS.remove(&key); debug_assert!(old_store.is_some()); remove_counter += 1; From 5af8c7121473ba44897b45c69f682915628f8ac6 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Thu, 18 Sep 2025 18:55:22 +0000 Subject: [PATCH 11/11] revert: using transmute to ignore mut refs of Sessions instead of Mutex lock --- ext/ai/lib.rs | 10 ++++------ ext/ai/onnxruntime/mod.rs | 17 ++++++----------- ext/ai/onnxruntime/model.rs | 22 +++++++--------------- ext/ai/onnxruntime/session.rs | 25 +++++++++++++++---------- 4 files changed, 32 insertions(+), 42 deletions(-) diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index c694c1d92..d2d73b4ae 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -40,6 +40,8 @@ use tracing::error; use tracing::trace_span; use tracing::Instrument; +use crate::onnxruntime::session::as_mut_session; + deno_core::extension!( ai, ops = [ @@ -213,14 +215,10 @@ async fn init_gte(state: Rc>) -> Result<(), Error> { &*token_type_ids, ))?; - let Ok(mut guard) = session.lock() else { - let err = anyhow!("failed to lock session"); - error!(reason = ?err); - return Err(err); - }; + let session = unsafe { as_mut_session(&session) }; let outputs = trace_span!("infer_gte").in_scope(|| { - guard.run(inputs! { + session.run(inputs! { "input_ids" => input_ids_array, "token_type_ids" => token_type_ids_array, "attention_mask" => attention_mask_array, diff --git a/ext/ai/onnxruntime/mod.rs b/ext/ai/onnxruntime/mod.rs index 0eda4792e..c591c8f21 100644 --- a/ext/ai/onnxruntime/mod.rs +++ b/ext/ai/onnxruntime/mod.rs @@ -10,7 +10,6 @@ use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; use std::sync::Arc; -use std::sync::Mutex; use anyhow::anyhow; use anyhow::Context; @@ -33,6 +32,8 @@ use tokio::sync::oneshot; use tracing::debug; use tracing::trace; +use crate::onnxruntime::session::as_mut_session; + #[op2(async)] #[to_v8] pub async fn op_ai_ort_init_session( @@ -56,11 +57,8 @@ pub async fn op_ai_ort_init_session( }; let mut state = state.borrow_mut(); - let mut sessions = { - state - .try_take::>>>() - .unwrap_or_default() - }; + let mut sessions = + { state.try_take::>>().unwrap_or_default() }; sessions.push(model.get_session()); state.put(sessions); @@ -107,12 +105,9 @@ pub async fn op_ai_ort_run_session( JsRuntime::op_state_from(state) .borrow_mut() .spawn_cpu_accumul_blocking_scope(move || { - let Ok(mut session_guard) = model_session.lock() else { - let _ = tx.send(Err(anyhow!("failed to lock model session"))); - return; - }; + let session = unsafe { as_mut_session(&model_session) }; - let outputs = match session_guard.run(input_values) { + let outputs = match session.run(input_values) { Ok(v) => v, Err(err) => { let _ = tx.send(Err(anyhow::Error::from(err))); diff --git a/ext/ai/onnxruntime/model.rs b/ext/ai/onnxruntime/model.rs index 0f8925845..475c2a4ef 100644 --- a/ext/ai/onnxruntime/model.rs +++ b/ext/ai/onnxruntime/model.rs @@ -1,12 +1,9 @@ -use std::sync::Arc; -use std::sync::Mutex; - -use anyhow::anyhow; use anyhow::Result; use deno_core::serde_v8::to_v8; use deno_core::ToV8; use ort::session::Session; use reqwest::Url; +use std::sync::Arc; use super::session::get_session; use super::session::load_session_from_bytes; @@ -29,26 +26,21 @@ impl std::fmt::Display for ModelInfo { #[derive(Debug)] pub struct Model { info: ModelInfo, - session: Arc>, + session: Arc, } impl Model { fn new(session_with_id: SessionWithId) -> Result { let (input_names, output_names) = { - let Ok(session_guard) = session_with_id.session.lock() else { - return Err(anyhow!( - "Could not lock model session {}", - session_with_id.id - )); - }; - - let input_names = session_guard + let session = { session_with_id.session.clone() }; + + let input_names = session .inputs .iter() .map(|input| input.name.clone()) .collect::>(); - let output_names = session_guard + let output_names = session .outputs .iter() .map(|output| output.name.clone()) @@ -71,7 +63,7 @@ impl Model { self.info.clone() } - pub fn get_session(&self) -> Arc> { + pub fn get_session(&self) -> Arc { self.session.clone() } diff --git a/ext/ai/onnxruntime/session.rs b/ext/ai/onnxruntime/session.rs index 0d8478eaa..dd0d009b9 100644 --- a/ext/ai/onnxruntime/session.rs +++ b/ext/ai/onnxruntime/session.rs @@ -5,7 +5,6 @@ use once_cell::sync::Lazy; use reqwest::Url; use std::hash::Hasher; use std::sync::Arc; -use std::sync::Mutex; use tokio_util::compat::FuturesAsyncWriteCompatExt; use tracing::debug; use tracing::instrument; @@ -25,17 +24,16 @@ use ort::session::Session; use crate::onnx::ensure_onnx_env_init; -static SESSIONS: Lazy>>> = - Lazy::new(DashMap::new); +static SESSIONS: Lazy>> = Lazy::new(DashMap::new); #[derive(Debug)] pub struct SessionWithId { pub(crate) id: String, - pub(crate) session: Arc>, + pub(crate) session: Arc, } -impl From<(String, Arc>)> for SessionWithId { - fn from(value: (String, Arc>)) -> Self { +impl From<(String, Arc)> for SessionWithId { + fn from(value: (String, Arc)) -> Self { Self { id: value.0, session: value.1, @@ -50,7 +48,7 @@ impl std::fmt::Display for SessionWithId { } impl SessionWithId { - pub fn into_split(self) -> (String, Arc>) { + pub fn into_split(self) -> (String, Arc) { (self.id, self.session) } } @@ -106,7 +104,7 @@ fn get_execution_providers() -> Vec { [cpu].to_vec() } -fn create_session(model_bytes: &[u8]) -> Result>, Error> { +fn create_session(model_bytes: &[u8]) -> Result, Error> { let session = { if let Some(err) = ensure_onnx_env_init() { return Err(anyhow!("failed to create onnx environment: {err}")); @@ -117,7 +115,14 @@ fn create_session(model_bytes: &[u8]) -> Result>, Error> { .commit_from_memory(model_bytes)? }; - Ok(Arc::new(Mutex::new(session))) + Ok(Arc::new(session)) +} + +#[allow(mutable_transmutes)] +#[allow(clippy::mut_from_ref)] +pub(crate) unsafe fn as_mut_session(session: &Arc) -> &mut Session { + // SAFETY: CPU EP https://github.com/pykeio/ort/issues/402#issuecomment-2949993914 + unsafe { std::mem::transmute::<&Session, &mut Session>(&session.clone()) } } #[instrument(level = "debug", skip_all, fields(model_bytes = model_bytes.len()), err)] @@ -174,7 +179,7 @@ pub(crate) async fn load_session_from_url( Ok((session_id, session).into()) } -pub(crate) async fn get_session(id: &str) -> Option>> { +pub(crate) async fn get_session(id: &str) -> Option> { SESSIONS.get(id).map(|value| value.pair().1.clone()) }