@@ -3,19 +3,29 @@ pub(crate) mod onnx;
3
3
pub ( crate ) mod session;
4
4
mod tensor;
5
5
6
- use std:: { borrow:: Cow , collections:: HashMap } ;
6
+ use std:: { borrow:: Cow , cell :: RefCell , collections:: HashMap , rc :: Rc , sync :: Arc } ;
7
7
8
- use anyhow:: Result ;
9
- use deno_core:: op2;
8
+ use anyhow:: { anyhow , Result } ;
9
+ use deno_core:: { op2, OpState } ;
10
10
11
11
use model_session:: { ModelInfo , ModelSession } ;
12
+ use ort:: Session ;
12
13
use tensor:: { JsTensor , ToJsTensor } ;
13
14
14
15
#[ op2]
15
16
#[ to_v8]
16
- pub fn op_sb_ai_ort_init_session ( #[ buffer] model_bytes : & [ u8 ] ) -> Result < ModelInfo > {
17
+ pub fn op_sb_ai_ort_init_session (
18
+ state : Rc < RefCell < OpState > > ,
19
+ #[ buffer] model_bytes : & [ u8 ] ,
20
+ ) -> Result < ModelInfo > {
21
+ let mut state = state. borrow_mut ( ) ;
17
22
let model_info = ModelSession :: from_bytes ( model_bytes) ?;
18
23
24
+ let mut sessions = { state. try_take :: < Vec < Arc < Session > > > ( ) . unwrap_or_default ( ) } ;
25
+
26
+ sessions. push ( model_info. inner ( ) ) ;
27
+ state. put ( sessions) ;
28
+
19
29
Ok ( model_info. info ( ) )
20
30
}
21
31
@@ -25,10 +35,11 @@ pub fn op_sb_ai_ort_run_session(
25
35
#[ string] model_id : String ,
26
36
#[ serde] input_values : HashMap < String , JsTensor > ,
27
37
) -> Result < HashMap < String , ToJsTensor > > {
28
- let model = ModelSession :: from_id ( model_id) . unwrap ( ) ;
38
+ let model = ModelSession :: from_id ( model_id. to_owned ( ) )
39
+ . ok_or ( anyhow ! ( "could not found session for id={model_id:?}" ) ) ?;
40
+
29
41
let model_session = model. inner ( ) ;
30
42
31
- // println!("{model_session:?}");
32
43
let input_values = input_values
33
44
. into_iter ( )
34
45
. map ( |( key, value) | {
@@ -44,7 +55,9 @@ pub fn op_sb_ai_ort_run_session(
44
55
// We need to `pop` over outputs to get 'value' ownership, since keys are attached to 'model_session' lifetime
45
56
// it can't be iterated with `into_iter()`
46
57
for _ in 0 ..outputs. len ( ) {
47
- let ( key, value) = outputs. pop_first ( ) . unwrap ( ) ;
58
+ let ( key, value) = outputs. pop_first ( ) . ok_or ( anyhow ! (
59
+ "could not retrieve output value from model session"
60
+ ) ) ?;
48
61
49
62
let value = ToJsTensor :: from_ort_tensor ( value) ?;
50
63
0 commit comments