1
+ use dashmap:: DashMap ;
1
2
use deno_core:: error:: AnyError ;
2
3
use futures:: io:: AllowStdIo ;
3
4
use once_cell:: sync:: Lazy ;
4
5
use reqwest:: Url ;
5
- use std:: collections:: HashMap ;
6
6
use std:: hash:: Hasher ;
7
7
use std:: sync:: Arc ;
8
8
use std:: sync:: Mutex ;
9
- use tokio:: sync:: Mutex as AsyncMutex ;
10
9
use tokio_util:: compat:: FuturesAsyncWriteCompatExt ;
11
10
use tracing:: debug;
12
11
use tracing:: instrument;
@@ -26,8 +25,8 @@ use ort::session::Session;
26
25
27
26
use crate :: onnx:: ensure_onnx_env_init;
28
27
29
- static SESSIONS : Lazy < AsyncMutex < HashMap < String , Arc < Mutex < Session > > > > > =
30
- Lazy :: new ( || AsyncMutex :: new ( HashMap :: new ( ) ) ) ;
28
+ static SESSIONS : Lazy < DashMap < String , Arc < Mutex < Session > > > > =
29
+ Lazy :: new ( DashMap :: new) ;
31
30
32
31
#[ derive( Debug ) ]
33
32
pub struct SessionWithId {
@@ -136,16 +135,14 @@ pub(crate) async fn load_session_from_bytes(
136
135
faster_hex:: hex_string ( & hasher. finish ( ) . to_be_bytes ( ) )
137
136
} ;
138
137
139
- let mut sessions = SESSIONS . lock ( ) . await ;
140
-
141
- if let Some ( session) = sessions. get ( & session_id) {
138
+ if let Some ( session) = SESSIONS . get ( & session_id) {
142
139
return Ok ( ( session_id, session. clone ( ) ) . into ( ) ) ;
143
140
}
144
141
145
142
trace ! ( session_id, "new session" ) ;
146
143
let session = create_session ( model_bytes) ?;
147
144
148
- sessions . insert ( session_id. clone ( ) , session. clone ( ) ) ;
145
+ SESSIONS . insert ( session_id. clone ( ) , session. clone ( ) ) ;
149
146
150
147
Ok ( ( session_id, session) . into ( ) )
151
148
}
@@ -156,9 +153,7 @@ pub(crate) async fn load_session_from_url(
156
153
) -> Result < SessionWithId , Error > {
157
154
let session_id = fxhash:: hash ( model_url. as_str ( ) ) . to_string ( ) ;
158
155
159
- let mut sessions = SESSIONS . lock ( ) . await ;
160
-
161
- if let Some ( session) = sessions. get ( & session_id) {
156
+ if let Some ( session) = SESSIONS . get ( & session_id) {
162
157
debug ! ( session_id, "use existing session" ) ;
163
158
return Ok ( ( session_id, session. clone ( ) ) . into ( ) ) ;
164
159
}
@@ -174,22 +169,23 @@ pub(crate) async fn load_session_from_url(
174
169
let session = create_session ( model_bytes. as_slice ( ) ) ?;
175
170
176
171
debug ! ( session_id, "new session" ) ;
177
- sessions . insert ( session_id. clone ( ) , session. clone ( ) ) ;
172
+ SESSIONS . insert ( session_id. clone ( ) , session. clone ( ) ) ;
178
173
179
174
Ok ( ( session_id, session) . into ( ) )
180
175
}
181
176
182
177
pub ( crate ) async fn get_session ( id : & str ) -> Option < Arc < Mutex < Session > > > {
183
- SESSIONS . lock ( ) . await . get ( id ) . cloned ( )
178
+ SESSIONS . get ( id ) . map ( |value| value . pair ( ) . 1 . clone ( ) )
184
179
}
185
180
186
181
pub async fn cleanup ( ) -> Result < usize , AnyError > {
187
182
let mut remove_counter = 0 ;
188
183
{
189
- let mut guard = SESSIONS . lock ( ) . await ;
184
+ // let mut guard = SESSIONS.lock().await;
190
185
let mut to_be_removed = vec ! [ ] ;
191
186
192
- for ( key, session) in & mut * guard {
187
+ for v in SESSIONS . iter ( ) {
188
+ let ( key, session) = v. pair ( ) ;
193
189
// Since we're currently referencing the session at this point
194
190
// It also will increments the counter, so we need to check: counter > 1
195
191
if Arc :: strong_count ( session) > 1 {
@@ -200,7 +196,7 @@ pub async fn cleanup() -> Result<usize, AnyError> {
200
196
}
201
197
202
198
for key in to_be_removed {
203
- let old_store = guard . remove ( & key) ;
199
+ let old_store = SESSIONS . remove ( & key) ;
204
200
debug_assert ! ( old_store. is_some( ) ) ;
205
201
206
202
remove_counter += 1 ;
0 commit comments