diff --git a/.env b/.env index f4c018213..201571e3f 100644 --- a/.env +++ b/.env @@ -1,4 +1,4 @@ -ONNXRUNTIME_VERSION=1.20.1 +ONNXRUNTIME_VERSION=1.22.0 DENO_VERSION=2.1.4 EDGE_RUNTIME_PORT=9998 AI_INFERENCE_API_HOST=http://localhost:11434 diff --git a/Cargo.lock b/Cargo.lock index 659f6af5d..7fb0b1d6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2058,7 +2058,7 @@ dependencies = [ "serde", "serde_json", "serde_v8", - "smallvec", + "smallvec 1.13.1", "sourcemap 8.0.1", "static_assertions", "tokio", @@ -2293,7 +2293,7 @@ dependencies = [ "ring 0.17.14", "scopeguard", "serde", - "smallvec", + "smallvec 1.13.1", "thiserror 2.0.8", "tokio", "tokio-util", @@ -3153,6 +3153,7 @@ dependencies = [ "clap", "convert_case", "ctor", + "dashmap", "deno_core", "ext_ai_v8_utilities", "faster-hex", @@ -4154,7 +4155,7 @@ dependencies = [ "rand", "resolv-conf", "serde", - "smallvec", + "smallvec 1.13.1", "thiserror 2.0.8", "tokio", "tracing", @@ -4338,7 +4339,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "smallvec", + "smallvec 1.13.1", "tokio", "want", ] @@ -4515,7 +4516,7 @@ dependencies = [ "icu_normalizer_data", "icu_properties", "icu_provider", - "smallvec", + "smallvec 1.13.1", "utf16_iter", "utf8_iter", "write16", @@ -4590,7 +4591,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ "idna_adapter", - "smallvec", + "smallvec 1.13.1", "utf8_iter", ] @@ -5200,7 +5201,7 @@ dependencies = [ "parking_lot", "quanta", "rustc_version 0.4.0", - "smallvec", + "smallvec 1.13.1", "tagptr", "thiserror 1.0.62", "triomphe", @@ -5434,7 +5435,7 @@ dependencies = [ "num-traits", "rand", "serde", - "smallvec", + "smallvec 1.13.1", "zeroize", ] @@ -5730,22 +5731,23 @@ checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" [[package]] name = "ort" -version = "2.0.0-rc.9" +version = "2.0.0-rc.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52afb44b6b0cffa9bf45e4d37e5a4935b0334a51570658e279e9e3e6cf324aa5" +checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721" dependencies = [ "half", "libloading 0.8.1", "ndarray", "ort-sys", + "smallvec 2.0.0-alpha.10", "tracing", ] [[package]] name = "ort-sys" -version = "2.0.0-rc.9" +version = "2.0.0-rc.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c41d7757331aef2d04b9cb09b45583a59217628beaf91895b7e76187b6e8c088" +checksum = "e2aba9f5c7c479925205799216e7e5d07cc1d4fa76ea8058c60a9a30f6a4e890" dependencies = [ "pkg-config", ] @@ -5858,7 +5860,7 @@ dependencies = [ "cfg-if", "libc", "redox_syscall", - "smallvec", + "smallvec 1.13.1", "windows-targets 0.48.5", ] @@ -6776,7 +6778,7 @@ dependencies = [ "fallible-streaming-iterator", "hashlink 0.9.1", "libsqlite3-sys", - "smallvec", + "smallvec 1.13.1", ] [[package]] @@ -7223,7 +7225,7 @@ source = "git+https://github.com/supabase/deno_core?branch=324-supabase#4f0b7554 dependencies = [ "num-bigint", "serde", - "smallvec", + "smallvec 1.13.1", "thiserror 1.0.62", "v8", ] @@ -7436,6 +7438,12 @@ version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" +[[package]] +name = "smallvec" +version = "2.0.0-alpha.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d44cfb396c3caf6fbfd0ab422af02631b69ddd96d2eff0b0f0724f9024051b" + [[package]] name = "smartstring" version = "1.0.1" @@ -7831,7 +7839,7 @@ dependencies = [ "num-traits", "phf", "serde", - "smallvec", + "smallvec 1.13.1", "smartstring", "stacker", "swc_atoms", @@ -7854,7 +7862,7 @@ dependencies = [ "phf", "rustc-hash 1.1.0", "serde", - "smallvec", + "smallvec 1.13.1", "swc_atoms", "swc_common", "swc_ecma_ast", @@ -7923,7 +7931,7 @@ dependencies = [ "either", "rustc-hash 1.1.0", "serde", - "smallvec", + "smallvec 1.13.1", "swc_atoms", "swc_common", "swc_ecma_ast", @@ -8644,7 +8652,7 @@ dependencies = [ "serde", "serde_json", "sharded-slab", - "smallvec", + "smallvec 1.13.1", "thread_local", "tracing", "tracing-core", @@ -8781,7 +8789,7 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" dependencies = [ - "smallvec", + "smallvec 1.13.1", ] [[package]] @@ -9150,7 +9158,7 @@ dependencies = [ "ron", "rustc-hash 1.1.0", "serde", - "smallvec", + "smallvec 1.13.1", "thiserror 1.0.62", "web-sys", "wgpu-hal", @@ -9191,7 +9199,7 @@ dependencies = [ "range-alloc", "raw-window-handle", "rustc-hash 1.1.0", - "smallvec", + "smallvec 1.13.1", "thiserror 1.0.62", "wasm-bindgen", "web-sys", diff --git a/ext/ai/Cargo.toml b/ext/ai/Cargo.toml index fb4c2b48d..444057b42 100644 --- a/ext/ai/Cargo.toml +++ b/ext/ai/Cargo.toml @@ -17,6 +17,7 @@ ext_ai_v8_utilities.workspace = true anyhow.workspace = true clap = { workspace = true, features = ["derive"] } ctor.workspace = true +dashmap.workspace = true faster-hex.workspace = true futures.workspace = true futures-util = { workspace = true, features = ["io"] } @@ -39,5 +40,5 @@ openblas-src = { version = "0.10", features = ['cblas', 'system'] } rand = "0.8" tokenizers = { version = ">=0.13.4", default-features = false, features = ["onig"] } -ort = { version = "=2.0.0-rc.9", default-features = false, features = ["ndarray", "half", "load-dynamic", "cuda"] } -ort-sys = "=2.0.0-rc.9" +ort = { version = "=2.0.0-rc.10", default-features = false, features = ["ndarray", "half", "load-dynamic", "cuda"] } +ort-sys = "=2.0.0-rc.10" diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index 224b0450f..d2d73b4ae 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -21,6 +21,7 @@ use ndarray::ArrayView3; use ndarray::Axis; use ndarray::Ix3; use ort::inputs; +use ort::value::TensorRef; use reqwest::Url; use session::load_session_from_url; use std::cell::RefCell; @@ -39,6 +40,8 @@ use tracing::error; use tracing::trace_span; use tracing::Instrument; +use crate::onnxruntime::session::as_mut_session; + deno_core::extension!( ai, ops = [ @@ -180,6 +183,7 @@ async fn init_gte(state: Rc>) -> Result<(), Error> { -> Result, Error> { let encoded_prompt = tokenizer.encode(prompt, true).map_err(anyhow::Error::msg)?; + let input_ids = encoded_prompt .get_ids() .iter() @@ -198,32 +202,41 @@ async fn init_gte(state: Rc>) -> Result<(), Error> { .map(|i| *i as i64) .collect::>(); - let input_ids_array = Array1::from_iter(input_ids.iter().cloned()); - let input_ids_array = input_ids_array.view().insert_axis(Axis(0)); + // Convert our flattened arrays into 2-dimensional tensors of shape [N, L] -> Since we're not batching 'N' will be always = 1 + let input_ids_array = + TensorRef::from_array_view(([1, input_ids.len()], &*input_ids))?; + let attention_mask_array = TensorRef::from_array_view(( + [1, encoded_prompt.len()], + &*attention_mask, + ))?; - let attention_mask_array = - Array1::from_iter(attention_mask.iter().cloned()); - let attention_mask_array = - attention_mask_array.view().insert_axis(Axis(0)); + let token_type_ids_array = TensorRef::from_array_view(( + [1, encoded_prompt.len()], + &*token_type_ids, + ))?; - let token_type_ids_array = - Array1::from_iter(token_type_ids.iter().cloned()); - let token_type_ids_array = - token_type_ids_array.view().insert_axis(Axis(0)); + let session = unsafe { as_mut_session(&session) }; let outputs = trace_span!("infer_gte").in_scope(|| { session.run(inputs! { - "input_ids" => input_ids_array, - "token_type_ids" => token_type_ids_array, - "attention_mask" => attention_mask_array, - }?) + "input_ids" => input_ids_array, + "token_type_ids" => token_type_ids_array, + "attention_mask" => attention_mask_array, + }) })?; - let embeddings = outputs["last_hidden_state"].try_extract_tensor()?; + let embeddings = outputs["last_hidden_state"].try_extract_array()?; let embeddings = embeddings.into_dimensionality::()?; let result = if do_mean_pooling { - mean_pool(embeddings, attention_mask_array.insert_axis(Axis(2))) + let attention_mask_array_clone = + Array1::from_iter(attention_mask.iter().cloned()); + let attention_mask_array_clone = attention_mask_array_clone + .view() + .insert_axis(Axis(0)) + .insert_axis(Axis(2)); + + mean_pool(embeddings, attention_mask_array_clone) } else { embeddings.into_owned().remove_axis(Axis(0)) }; diff --git a/ext/ai/onnxruntime/mod.rs b/ext/ai/onnxruntime/mod.rs index b1dd136f8..c591c8f21 100644 --- a/ext/ai/onnxruntime/mod.rs +++ b/ext/ai/onnxruntime/mod.rs @@ -32,6 +32,8 @@ use tokio::sync::oneshot; use tracing::debug; use tracing::trace; +use crate::onnxruntime::session::as_mut_session; + #[op2(async)] #[to_v8] pub async fn op_ai_ort_init_session( @@ -103,7 +105,9 @@ pub async fn op_ai_ort_run_session( JsRuntime::op_state_from(state) .borrow_mut() .spawn_cpu_accumul_blocking_scope(move || { - let outputs = match model_session.run(input_values) { + let session = unsafe { as_mut_session(&model_session) }; + + let outputs = match session.run(input_values) { Ok(v) => v, Err(err) => { let _ = tx.send(Err(anyhow::Error::from(err))); diff --git a/ext/ai/onnxruntime/model.rs b/ext/ai/onnxruntime/model.rs index f3a17e6e4..475c2a4ef 100644 --- a/ext/ai/onnxruntime/model.rs +++ b/ext/ai/onnxruntime/model.rs @@ -1,10 +1,9 @@ -use std::sync::Arc; - use anyhow::Result; use deno_core::serde_v8::to_v8; use deno_core::ToV8; use ort::session::Session; use reqwest::Url; +use std::sync::Arc; use super::session::get_session; use super::session::load_session_from_bytes; @@ -31,29 +30,33 @@ pub struct Model { } impl Model { - fn new(session_with_id: SessionWithId) -> Self { - let input_names = session_with_id - .session - .inputs - .iter() - .map(|input| input.name.clone()) - .collect::>(); - - let output_names = session_with_id - .session - .outputs - .iter() - .map(|output| output.name.clone()) - .collect::>(); - - Self { + fn new(session_with_id: SessionWithId) -> Result { + let (input_names, output_names) = { + let session = { session_with_id.session.clone() }; + + let input_names = session + .inputs + .iter() + .map(|input| input.name.clone()) + .collect::>(); + + let output_names = session + .outputs + .iter() + .map(|output| output.name.clone()) + .collect::>(); + + (input_names, output_names) + }; + + Ok(Self { info: ModelInfo { id: session_with_id.id, input_names, output_names, }, session: session_with_id.session, - } + }) } pub fn get_info(&self) -> ModelInfo { @@ -65,18 +68,27 @@ impl Model { } pub async fn from_id(id: &str) -> Option { - get_session(id) - .await - .map(|it| SessionWithId::from((id.to_string(), it))) - .map(Self::new) + let session = { + get_session(id) + .await + .map(|it| SessionWithId::from((id.to_string(), it))) + }; + + let session = session?; + + Self::new(session).ok() } pub async fn from_url(model_url: Url) -> Result { - load_session_from_url(model_url).await.map(Self::new) + let session = load_session_from_url(model_url).await?; + + Self::new(session) } pub async fn from_bytes(model_bytes: &[u8]) -> Result { - load_session_from_bytes(model_bytes).await.map(Self::new) + let session = load_session_from_bytes(model_bytes).await?; + + Self::new(session) } } diff --git a/ext/ai/onnxruntime/session.rs b/ext/ai/onnxruntime/session.rs index 6205e8550..dd0d009b9 100644 --- a/ext/ai/onnxruntime/session.rs +++ b/ext/ai/onnxruntime/session.rs @@ -1,11 +1,10 @@ +use dashmap::DashMap; use deno_core::error::AnyError; use futures::io::AllowStdIo; use once_cell::sync::Lazy; use reqwest::Url; -use std::collections::HashMap; use std::hash::Hasher; use std::sync::Arc; -use tokio::sync::Mutex; use tokio_util::compat::FuturesAsyncWriteCompatExt; use tracing::debug; use tracing::instrument; @@ -25,8 +24,7 @@ use ort::session::Session; use crate::onnx::ensure_onnx_env_init; -static SESSIONS: Lazy>>> = - Lazy::new(|| Mutex::new(HashMap::new())); +static SESSIONS: Lazy>> = Lazy::new(DashMap::new); #[derive(Debug)] pub struct SessionWithId { @@ -74,39 +72,36 @@ pub(crate) fn get_session_builder() -> Result { Ok(builder) } -fn cpu_execution_provider( -) -> Box> { - Box::new( - [ - // NOTE(Nyannacha): See the comment above. This makes `enable_cpu_mem_arena` set to - // False. - // - // Backgrounds: - // [1]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#9-18 - // [2]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#46-50 - CPUExecutionProvider::default().build(), - ] - .into_iter(), - ) +fn cpu_execution_provider() -> ExecutionProviderDispatch { + // NOTE(Nyannacha): See the comment above. This makes `enable_cpu_mem_arena` set to + // False. + // + // Backgrounds: + // [1]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#9-18 + // [2]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#46-50 + CPUExecutionProvider::default().build() } -fn cuda_execution_provider( -) -> Box> { +fn cuda_execution_provider() -> Option { let cuda = CUDAExecutionProvider::default(); - let providers = match cuda.is_available() { - Ok(is_cuda_available) => { - debug!(cuda_support = is_cuda_available); - if is_cuda_available { - vec![cuda.build()] - } else { - vec![] - } - } + let is_cuda_available = cuda.is_available().is_ok_and(|v| v); + debug!(cuda_support = is_cuda_available); + + if is_cuda_available { + Some(cuda.build()) + } else { + None + } +} - _ => vec![], +fn get_execution_providers() -> Vec { + let cpu = cpu_execution_provider(); + + if let Some(cuda) = cuda_execution_provider() { + return [cuda, cpu].to_vec(); }; - Box::new(providers.into_iter().chain(cpu_execution_provider())) + [cpu].to_vec() } fn create_session(model_bytes: &[u8]) -> Result, Error> { @@ -116,13 +111,20 @@ fn create_session(model_bytes: &[u8]) -> Result, Error> { } get_session_builder()? - .with_execution_providers(cuda_execution_provider())? + .with_execution_providers(get_execution_providers())? .commit_from_memory(model_bytes)? }; Ok(Arc::new(session)) } +#[allow(mutable_transmutes)] +#[allow(clippy::mut_from_ref)] +pub(crate) unsafe fn as_mut_session(session: &Arc) -> &mut Session { + // SAFETY: CPU EP https://github.com/pykeio/ort/issues/402#issuecomment-2949993914 + unsafe { std::mem::transmute::<&Session, &mut Session>(&session.clone()) } +} + #[instrument(level = "debug", skip_all, fields(model_bytes = model_bytes.len()), err)] pub(crate) async fn load_session_from_bytes( model_bytes: &[u8], @@ -138,16 +140,14 @@ pub(crate) async fn load_session_from_bytes( faster_hex::hex_string(&hasher.finish().to_be_bytes()) }; - let mut sessions = SESSIONS.lock().await; - - if let Some(session) = sessions.get(&session_id) { + if let Some(session) = SESSIONS.get(&session_id) { return Ok((session_id, session.clone()).into()); } trace!(session_id, "new session"); let session = create_session(model_bytes)?; - sessions.insert(session_id.clone(), session.clone()); + SESSIONS.insert(session_id.clone(), session.clone()); Ok((session_id, session).into()) } @@ -158,9 +158,7 @@ pub(crate) async fn load_session_from_url( ) -> Result { let session_id = fxhash::hash(model_url.as_str()).to_string(); - let mut sessions = SESSIONS.lock().await; - - if let Some(session) = sessions.get(&session_id) { + if let Some(session) = SESSIONS.get(&session_id) { debug!(session_id, "use existing session"); return Ok((session_id, session.clone()).into()); } @@ -176,22 +174,23 @@ pub(crate) async fn load_session_from_url( let session = create_session(model_bytes.as_slice())?; debug!(session_id, "new session"); - sessions.insert(session_id.clone(), session.clone()); + SESSIONS.insert(session_id.clone(), session.clone()); Ok((session_id, session).into()) } pub(crate) async fn get_session(id: &str) -> Option> { - SESSIONS.lock().await.get(id).cloned() + SESSIONS.get(id).map(|value| value.pair().1.clone()) } pub async fn cleanup() -> Result { let mut remove_counter = 0; { - let mut guard = SESSIONS.lock().await; + //let mut guard = SESSIONS.lock().await; let mut to_be_removed = vec![]; - for (key, session) in &mut *guard { + for v in SESSIONS.iter() { + let (key, session) = v.pair(); // Since we're currently referencing the session at this point // It also will increments the counter, so we need to check: counter > 1 if Arc::strong_count(session) > 1 { @@ -202,7 +201,7 @@ pub async fn cleanup() -> Result { } for key in to_be_removed { - let old_store = guard.remove(&key); + let old_store = SESSIONS.remove(&key); debug_assert!(old_store.is_some()); remove_counter += 1; diff --git a/ext/ai/onnxruntime/tensor.rs b/ext/ai/onnxruntime/tensor.rs index 86b2e7ae8..a3758a574 100644 --- a/ext/ai/onnxruntime/tensor.rs +++ b/ext/ai/onnxruntime/tensor.rs @@ -14,6 +14,7 @@ use ort::memory::MemoryInfo; use ort::memory::MemoryType; use ort::session::SessionInputValue; use ort::tensor::PrimitiveTensorElementType; +use ort::tensor::Shape; use ort::tensor::TensorElementType; use ort::value::DynValue; use ort::value::DynValueTypeMarker; @@ -32,7 +33,7 @@ macro_rules! v8_slice_from { (tensor::<$type:ident>($tensor:expr)) => {{ // We must ensure there's some detection to avoid `null pointer` errors // https://github.com/pykeio/ort/issues/185 - let n_detections = $tensor.shape()?[0]; + let n_detections = $tensor.shape()[0]; if n_detections == 0 { let buf_store = v8::ArrayBuffer::new_backing_store_from_vec(vec![]).make_shared(); @@ -43,7 +44,7 @@ macro_rules! v8_slice_from { buffer_slice } else { let (_, raw_tensor) = $tensor - .try_extract_raw_tensor_mut::<$type>() + .try_extract_tensor_mut::<$type>() .map_err(AnyError::from)?; let tensor_ptr = raw_tensor.as_ptr(); @@ -113,6 +114,24 @@ pub enum JsTensorType { Uint64, /// Brain 16-bit floating point number, equivalent to [`half::bf16`] (requires the `half` feature). Bfloat16, + Complex64, + Complex128, + /// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values and no infinite + /// values. + Float8E4M3FN, + /// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values, no infinite + /// values, and no negative zero. + Float8E4M3FNUZ, + /// 8-bit floating point number with 5 exponent bits and 2 mantissa bits. + Float8E5M2, + /// 8-bit floating point number with 5 exponent bits and 2 mantissa bits, with only NaN values, no infinite + /// values, and no negative zero. + Float8E5M2FNUZ, + /// 4-bit unsigned integer. + Uint4, + /// 4-bit signed integer. + Int4, + Undefined, } #[derive(Serialize, Deserialize)] @@ -183,7 +202,7 @@ impl JsTensor { TensorRefMut::::from_raw( memory_info, data.as_mut_ptr() as *mut c_void, - self.dims, + Shape::new(self.dims), ) }?; @@ -205,7 +224,7 @@ impl JsTensor { )); }; - Tensor::from_string_array((self.dims, data))?.into() + Tensor::from_string_array((self.dims, data.as_slice()))?.into() } TensorElementType::Int8 => self.extract_ort_tensor_ref::()?.into(), TensorElementType::Uint8 => self.extract_ort_tensor_ref::()?.into(), @@ -216,11 +235,8 @@ impl JsTensor { TensorElementType::Int64 => self.extract_ort_tensor_ref::()?.into(), TensorElementType::Uint64 => self.extract_ort_tensor_ref::()?.into(), TensorElementType::Bool => self.extract_ort_tensor_ref::()?.into(), - TensorElementType::Float16 => { - return Err(anyhow!("'half::f16' is not supported by JS tensor.")) - } - TensorElementType::Bfloat16 => { - return Err(anyhow!("'half::bf16' is not supported by JS tensor.")) + other => { + return Err(anyhow!("'{other:?}' is not supported by JS tensor.")) } }; @@ -243,7 +259,7 @@ impl ToJsTensor { "JS only support 'ort::Value' of 'Tensor' type, got '{value:?}'." )); }; - let tensor_shape = value.shape()?; + let tensor_shape = value.shape().to_vec(); let buffer_slice = match tensor_type { TensorElementType::Float32 => v8_slice_from!(tensor::(value)), @@ -257,9 +273,9 @@ impl ToJsTensor { TensorElementType::Int64 => v8_slice_from!(tensor::(value)), TensorElementType::Uint64 => v8_slice_from!(tensor::(value)), TensorElementType::Bool => v8_slice_from!(tensor::(value)), - TensorElementType::String => todo!(), - TensorElementType::Float16 => todo!(), - TensorElementType::Bfloat16 => todo!(), + other => { + return Err(anyhow!("'{other:?}' is not supported by JS tensor.")) + } }; Ok(Self {