Skip to content

Commit e0c0ea6

Browse files
committed
stamp: solving ort api breaking changes
1 parent 58733c0 commit e0c0ea6

File tree

5 files changed

+128
-76
lines changed

5 files changed

+128
-76
lines changed

ext/ai/lib.rs

Lines changed: 23 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ use ndarray::ArrayView3;
2121
use ndarray::Axis;
2222
use ndarray::Ix3;
2323
use ort::inputs;
24+
use ort::value::TensorRef;
2425
use reqwest::Url;
2526
use session::load_session_from_url;
2627
use std::cell::RefCell;
@@ -180,6 +181,7 @@ async fn init_gte(state: Rc<RefCell<OpState>>) -> Result<(), Error> {
180181
-> Result<Vec<f32>, Error> {
181182
let encoded_prompt =
182183
tokenizer.encode(prompt, true).map_err(anyhow::Error::msg)?;
184+
183185
let input_ids = encoded_prompt
184186
.get_ids()
185187
.iter()
@@ -198,32 +200,36 @@ async fn init_gte(state: Rc<RefCell<OpState>>) -> Result<(), Error> {
198200
.map(|i| *i as i64)
199201
.collect::<Vec<_>>();
200202

201-
let input_ids_array = Array1::from_iter(input_ids.iter().cloned());
202-
let input_ids_array = input_ids_array.view().insert_axis(Axis(0));
203+
let input_ids_array = TensorRef::from_array_view(([input_ids.len(), 1], &*input_ids))?;
204+
let attention_mask_array = TensorRef::from_array_view(([1, encoded_prompt.len()], &*attention_mask))?;
205+
let token_type_ids_array = TensorRef::from_array_view(([1, encoded_prompt.len()], &*token_type_ids))?;
203206

204-
let attention_mask_array =
205-
Array1::from_iter(attention_mask.iter().cloned());
206-
let attention_mask_array =
207-
attention_mask_array.view().insert_axis(Axis(0));
208207

209-
let token_type_ids_array =
210-
Array1::from_iter(token_type_ids.iter().cloned());
211-
let token_type_ids_array =
212-
token_type_ids_array.view().insert_axis(Axis(0));
208+
let Ok(mut guard) = session.lock() else {
209+
let err = anyhow!("failed to lock session");
210+
error!(reason = ?err);
211+
return Err(err);
212+
};
213213

214214
let outputs = trace_span!("infer_gte").in_scope(|| {
215-
session.run(inputs! {
216-
"input_ids" => input_ids_array,
217-
"token_type_ids" => token_type_ids_array,
218-
"attention_mask" => attention_mask_array,
219-
}?)
215+
guard.run(inputs! {
216+
"input_ids" => input_ids_array,
217+
"token_type_ids" => token_type_ids_array,
218+
"attention_mask" => attention_mask_array,
219+
})
220220
})?;
221221

222-
let embeddings = outputs["last_hidden_state"].try_extract_tensor()?;
222+
let embeddings = outputs["last_hidden_state"].try_extract_array()?;
223223
let embeddings = embeddings.into_dimensionality::<Ix3>()?;
224224

225225
let result = if do_mean_pooling {
226-
mean_pool(embeddings, attention_mask_array.insert_axis(Axis(2)))
226+
let attention_mask_array_clone= Array1::from_iter(attention_mask.iter().cloned());
227+
let attention_mask_array_clone= attention_mask_array_clone.view()
228+
.insert_axis(Axis(0))
229+
.insert_axis(Axis(2));
230+
231+
println!("attention_mask: {attention_mask_array_clone:?}");
232+
mean_pool(embeddings, attention_mask_array_clone)
227233
} else {
228234
embeddings.into_owned().remove_axis(Axis(0))
229235
};

ext/ai/onnxruntime/mod.rs

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ use std::cell::RefCell;
1010
use std::collections::HashMap;
1111
use std::rc::Rc;
1212
use std::sync::Arc;
13+
use std::sync::Mutex;
1314

1415
use anyhow::anyhow;
1516
use anyhow::Context;
@@ -56,7 +57,7 @@ pub async fn op_ai_ort_init_session(
5657

5758
let mut state = state.borrow_mut();
5859
let mut sessions =
59-
{ state.try_take::<Vec<Arc<Session>>>().unwrap_or_default() };
60+
{ state.try_take::<Vec<Arc<Mutex<Session>>>>().unwrap_or_default() };
6061

6162
sessions.push(model.get_session());
6263
state.put(sessions);
@@ -103,7 +104,12 @@ pub async fn op_ai_ort_run_session(
103104
JsRuntime::op_state_from(state)
104105
.borrow_mut()
105106
.spawn_cpu_accumul_blocking_scope(move || {
106-
let outputs = match model_session.run(input_values) {
107+
let Ok(mut session_guard) = model_session.lock() else {
108+
let _ = tx.send(Err(anyhow!("failed to lock model session")));
109+
return;
110+
};
111+
112+
let outputs = match session_guard.run(input_values) {
107113
Ok(v) => v,
108114
Err(err) => {
109115
let _ = tx.send(Err(anyhow::Error::from(err)));

ext/ai/onnxruntime/model.rs

Lines changed: 32 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
use std::sync::Arc;
2+
use std::sync::Mutex;
23

4+
use anyhow::anyhow;
35
use anyhow::Result;
46
use deno_core::serde_v8::to_v8;
57
use deno_core::ToV8;
@@ -27,56 +29,73 @@ impl std::fmt::Display for ModelInfo {
2729
#[derive(Debug)]
2830
pub struct Model {
2931
info: ModelInfo,
30-
session: Arc<Session>,
32+
session: Arc<Mutex<Session>>,
3133
}
3234

3335
impl Model {
34-
fn new(session_with_id: SessionWithId) -> Self {
35-
let input_names = session_with_id
36-
.session
36+
fn new(session_with_id: SessionWithId) -> Result<Self> {
37+
let (input_names, output_names) = {
38+
let Ok(session_guard) = session_with_id.session.lock() else {
39+
return Err(anyhow!("Could not lock model session {}", session_with_id.id));
40+
};
41+
42+
let input_names = session_guard
3743
.inputs
3844
.iter()
3945
.map(|input| input.name.clone())
4046
.collect::<Vec<_>>();
4147

42-
let output_names = session_with_id
43-
.session
48+
let output_names = session_guard
4449
.outputs
4550
.iter()
4651
.map(|output| output.name.clone())
4752
.collect::<Vec<_>>();
4853

49-
Self {
54+
(input_names, output_names)
55+
};
56+
57+
Ok(Self {
5058
info: ModelInfo {
5159
id: session_with_id.id,
5260
input_names,
5361
output_names,
5462
},
5563
session: session_with_id.session,
56-
}
64+
})
5765
}
5866

5967
pub fn get_info(&self) -> ModelInfo {
6068
self.info.clone()
6169
}
6270

63-
pub fn get_session(&self) -> Arc<Session> {
71+
pub fn get_session(&self) -> Arc<Mutex<Session>> {
6472
self.session.clone()
6573
}
6674

6775
pub async fn from_id(id: &str) -> Option<Self> {
68-
get_session(id)
76+
let session = {
77+
get_session(id)
6978
.await
7079
.map(|it| SessionWithId::from((id.to_string(), it)))
71-
.map(Self::new)
80+
};
81+
82+
let Some(session) = session else {
83+
return None;
84+
};
85+
86+
Self::new(session).ok()
7287
}
7388

7489
pub async fn from_url(model_url: Url) -> Result<Self> {
75-
load_session_from_url(model_url).await.map(Self::new)
90+
let session = load_session_from_url(model_url).await?;
91+
92+
Self::new(session)
7693
}
7794

7895
pub async fn from_bytes(model_bytes: &[u8]) -> Result<Self> {
79-
load_session_from_bytes(model_bytes).await.map(Self::new)
96+
let session = load_session_from_bytes(model_bytes).await?;
97+
98+
Self::new(session)
8099
}
81100
}
82101

ext/ai/onnxruntime/session.rs

Lines changed: 37 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,8 @@ use reqwest::Url;
55
use std::collections::HashMap;
66
use std::hash::Hasher;
77
use std::sync::Arc;
8-
use tokio::sync::Mutex;
8+
use std::sync::Mutex;
9+
use tokio::sync::Mutex as AsyncMutex;
910
use tokio_util::compat::FuturesAsyncWriteCompatExt;
1011
use tracing::debug;
1112
use tracing::instrument;
@@ -25,17 +26,17 @@ use ort::session::Session;
2526

2627
use crate::onnx::ensure_onnx_env_init;
2728

28-
static SESSIONS: Lazy<Mutex<HashMap<String, Arc<Session>>>> =
29-
Lazy::new(|| Mutex::new(HashMap::new()));
29+
static SESSIONS: Lazy<AsyncMutex<HashMap<String, Arc<Mutex<Session>>>>> =
30+
Lazy::new(|| AsyncMutex::new(HashMap::new()));
3031

3132
#[derive(Debug)]
3233
pub struct SessionWithId {
3334
pub(crate) id: String,
34-
pub(crate) session: Arc<Session>,
35+
pub(crate) session: Arc<Mutex<Session>>,
3536
}
3637

37-
impl From<(String, Arc<Session>)> for SessionWithId {
38-
fn from(value: (String, Arc<Session>)) -> Self {
38+
impl From<(String, Arc<Mutex<Session>>)> for SessionWithId {
39+
fn from(value: (String, Arc<Mutex<Session>>)) -> Self {
3940
Self {
4041
id: value.0,
4142
session: value.1,
@@ -50,7 +51,7 @@ impl std::fmt::Display for SessionWithId {
5051
}
5152

5253
impl SessionWithId {
53-
pub fn into_split(self) -> (String, Arc<Session>) {
54+
pub fn into_split(self) -> (String, Arc<Mutex<Session>>) {
5455
(self.id, self.session)
5556
}
5657
}
@@ -74,53 +75,51 @@ pub(crate) fn get_session_builder() -> Result<SessionBuilder, AnyError> {
7475
Ok(builder)
7576
}
7677

77-
fn cpu_execution_provider(
78-
) -> Box<dyn Iterator<Item = ExecutionProviderDispatch>> {
79-
Box::new(
80-
[
81-
// NOTE(Nyannacha): See the comment above. This makes `enable_cpu_mem_arena` set to
82-
// False.
83-
//
84-
// Backgrounds:
85-
// [1]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#9-18
86-
// [2]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#46-50
87-
CPUExecutionProvider::default().build(),
88-
]
89-
.into_iter(),
90-
)
78+
fn cpu_execution_provider() -> ExecutionProviderDispatch {
79+
// NOTE(Nyannacha): See the comment above. This makes `enable_cpu_mem_arena` set to
80+
// False.
81+
//
82+
// Backgrounds:
83+
// [1]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#9-18
84+
// [2]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#46-50
85+
CPUExecutionProvider::default().build()
9186
}
9287

93-
fn cuda_execution_provider(
94-
) -> Box<dyn Iterator<Item = ExecutionProviderDispatch>> {
88+
fn cuda_execution_provider() -> Option<ExecutionProviderDispatch> {
9589
let cuda = CUDAExecutionProvider::default();
96-
let providers = match cuda.is_available() {
97-
Ok(is_cuda_available) => {
98-
debug!(cuda_support = is_cuda_available);
99-
if is_cuda_available {
100-
vec![cuda.build()]
101-
} else {
102-
vec![]
103-
}
104-
}
90+
let is_cuda_available = cuda.is_available().is_ok_and(|v| v);
91+
debug!(cuda_support = is_cuda_available);
92+
93+
if is_cuda_available {
94+
Some(cuda.build())
95+
}else{
96+
None
97+
}
98+
}
99+
100+
fn get_execution_providers(
101+
) -> Vec<ExecutionProviderDispatch> {
102+
let cpu = cpu_execution_provider();
105103

106-
_ => vec![],
104+
if let Some(cuda) = cuda_execution_provider() {
105+
return [cuda, cpu].to_vec();
107106
};
108107

109-
Box::new(providers.into_iter().chain(cpu_execution_provider()))
108+
[cpu].to_vec()
110109
}
111110

112-
fn create_session(model_bytes: &[u8]) -> Result<Arc<Session>, Error> {
111+
fn create_session(model_bytes: &[u8]) -> Result<Arc<Mutex<Session>>, Error> {
113112
let session = {
114113
if let Some(err) = ensure_onnx_env_init() {
115114
return Err(anyhow!("failed to create onnx environment: {err}"));
116115
}
117116

118117
get_session_builder()?
119-
.with_execution_providers(cuda_execution_provider())?
118+
.with_execution_providers(get_execution_providers())?
120119
.commit_from_memory(model_bytes)?
121120
};
122121

123-
Ok(Arc::new(session))
122+
Ok(Arc::new(Mutex::new(session)))
124123
}
125124

126125
#[instrument(level = "debug", skip_all, fields(model_bytes = model_bytes.len()), err)]
@@ -181,7 +180,7 @@ pub(crate) async fn load_session_from_url(
181180
Ok((session_id, session).into())
182181
}
183182

184-
pub(crate) async fn get_session(id: &str) -> Option<Arc<Session>> {
183+
pub(crate) async fn get_session(id: &str) -> Option<Arc<Mutex<Session>>> {
185184
SESSIONS.lock().await.get(id).cloned()
186185
}
187186

0 commit comments

Comments
 (0)