Skip to content

Commit 8faa059

Browse files
committed
feat: using dashmap to store sessions
1 parent 3bf1963 commit 8faa059

File tree

3 files changed

+14
-16
lines changed

3 files changed

+14
-16
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ext/ai/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ ext_ai_v8_utilities.workspace = true
1717
anyhow.workspace = true
1818
clap = { workspace = true, features = ["derive"] }
1919
ctor.workspace = true
20+
dashmap.workspace = true
2021
faster-hex.workspace = true
2122
futures.workspace = true
2223
futures-util = { workspace = true, features = ["io"] }

ext/ai/onnxruntime/session.rs

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1+
use dashmap::DashMap;
12
use deno_core::error::AnyError;
23
use futures::io::AllowStdIo;
34
use once_cell::sync::Lazy;
45
use reqwest::Url;
5-
use std::collections::HashMap;
66
use std::hash::Hasher;
77
use std::sync::Arc;
88
use std::sync::Mutex;
9-
use tokio::sync::Mutex as AsyncMutex;
109
use tokio_util::compat::FuturesAsyncWriteCompatExt;
1110
use tracing::debug;
1211
use tracing::instrument;
@@ -26,8 +25,8 @@ use ort::session::Session;
2625

2726
use crate::onnx::ensure_onnx_env_init;
2827

29-
static SESSIONS: Lazy<AsyncMutex<HashMap<String, Arc<Mutex<Session>>>>> =
30-
Lazy::new(|| AsyncMutex::new(HashMap::new()));
28+
static SESSIONS: Lazy<DashMap<String, Arc<Mutex<Session>>>> =
29+
Lazy::new(DashMap::new);
3130

3231
#[derive(Debug)]
3332
pub struct SessionWithId {
@@ -136,16 +135,14 @@ pub(crate) async fn load_session_from_bytes(
136135
faster_hex::hex_string(&hasher.finish().to_be_bytes())
137136
};
138137

139-
let mut sessions = SESSIONS.lock().await;
140-
141-
if let Some(session) = sessions.get(&session_id) {
138+
if let Some(session) = SESSIONS.get(&session_id) {
142139
return Ok((session_id, session.clone()).into());
143140
}
144141

145142
trace!(session_id, "new session");
146143
let session = create_session(model_bytes)?;
147144

148-
sessions.insert(session_id.clone(), session.clone());
145+
SESSIONS.insert(session_id.clone(), session.clone());
149146

150147
Ok((session_id, session).into())
151148
}
@@ -156,9 +153,7 @@ pub(crate) async fn load_session_from_url(
156153
) -> Result<SessionWithId, Error> {
157154
let session_id = fxhash::hash(model_url.as_str()).to_string();
158155

159-
let mut sessions = SESSIONS.lock().await;
160-
161-
if let Some(session) = sessions.get(&session_id) {
156+
if let Some(session) = SESSIONS.get(&session_id) {
162157
debug!(session_id, "use existing session");
163158
return Ok((session_id, session.clone()).into());
164159
}
@@ -174,22 +169,23 @@ pub(crate) async fn load_session_from_url(
174169
let session = create_session(model_bytes.as_slice())?;
175170

176171
debug!(session_id, "new session");
177-
sessions.insert(session_id.clone(), session.clone());
172+
SESSIONS.insert(session_id.clone(), session.clone());
178173

179174
Ok((session_id, session).into())
180175
}
181176

182177
pub(crate) async fn get_session(id: &str) -> Option<Arc<Mutex<Session>>> {
183-
SESSIONS.lock().await.get(id).cloned()
178+
SESSIONS.get(id).map(|value| value.pair().1.clone())
184179
}
185180

186181
pub async fn cleanup() -> Result<usize, AnyError> {
187182
let mut remove_counter = 0;
188183
{
189-
let mut guard = SESSIONS.lock().await;
184+
//let mut guard = SESSIONS.lock().await;
190185
let mut to_be_removed = vec![];
191186

192-
for (key, session) in &mut *guard {
187+
for v in SESSIONS.iter() {
188+
let (key, session) = v.pair();
193189
// Since we're currently referencing the session at this point
194190
// It also will increments the counter, so we need to check: counter > 1
195191
if Arc::strong_count(session) > 1 {
@@ -200,7 +196,7 @@ pub async fn cleanup() -> Result<usize, AnyError> {
200196
}
201197

202198
for key in to_be_removed {
203-
let old_store = guard.remove(&key);
199+
let old_store = SESSIONS.remove(&key);
204200
debug_assert!(old_store.is_some());
205201

206202
remove_counter += 1;

0 commit comments

Comments
 (0)