@@ -13,21 +13,22 @@ pub mod onnx {
13
13
use dr_transform:: converter:: { BatchPredictionRequestToTorchTensorConverter , Converter } ;
14
14
use itertools:: Itertools ;
15
15
use log:: { debug, info} ;
16
- use ort:: environment:: Environment ;
17
- use ort:: session:: Session ;
18
- use ort:: tensor:: InputTensor ;
19
- use ort:: { ExecutionProvider , GraphOptimizationLevel , SessionBuilder } ;
16
+ use dr_transform:: ort:: environment:: Environment ;
17
+ use dr_transform:: ort:: session:: Session ;
18
+ use dr_transform:: ort:: tensor:: InputTensor ;
19
+ use dr_transform:: ort:: { ExecutionProvider , GraphOptimizationLevel , SessionBuilder } ;
20
+ use dr_transform:: ort:: LoggingLevel ;
20
21
use serde_json:: Value ;
21
22
use std:: fmt:: { Debug , Display } ;
22
23
use std:: sync:: Arc ;
23
24
use std:: { fmt, fs} ;
24
25
use tokio:: time:: Instant ;
25
-
26
26
lazy_static ! {
27
27
pub static ref ENVIRONMENT : Arc <Environment > = Arc :: new(
28
28
Environment :: builder( )
29
29
. with_name( "onnx home" )
30
- . with_log_level( ort:: LoggingLevel :: Error )
30
+ . with_log_level( LoggingLevel :: Error )
31
+ . with_global_thread_pool( ARGS . onnx_global_thread_pool_options. clone( ) )
31
32
. build( )
32
33
. unwrap( )
33
34
) ;
@@ -101,23 +102,30 @@ pub mod onnx {
101
102
let meta_info = format ! ( "{}/{}/{}" , ARGS . model_dir[ idx] , version, META_INFO ) ;
102
103
let mut builder = SessionBuilder :: new ( & ENVIRONMENT ) ?
103
104
. with_optimization_level ( GraphOptimizationLevel :: Level3 ) ?
104
- . with_parallel_execution ( ARGS . onnx_use_parallel_mode == "true" ) ?
105
- . with_inter_threads (
106
- utils:: get_config_or (
107
- model_config,
108
- "inter_op_parallelism" ,
109
- & ARGS . inter_op_parallelism [ idx] ,
110
- )
111
- . parse ( ) ?,
112
- ) ?
113
- . with_intra_threads (
114
- utils:: get_config_or (
115
- model_config,
116
- "intra_op_parallelism" ,
117
- & ARGS . intra_op_parallelism [ idx] ,
118
- )
119
- . parse ( ) ?,
120
- ) ?
105
+ . with_parallel_execution ( ARGS . onnx_use_parallel_mode == "true" ) ?;
106
+ if ARGS . onnx_global_thread_pool_options . is_empty ( ) {
107
+ builder = builder
108
+ . with_inter_threads (
109
+ utils:: get_config_or (
110
+ model_config,
111
+ "inter_op_parallelism" ,
112
+ & ARGS . inter_op_parallelism [ idx] ,
113
+ )
114
+ . parse ( ) ?,
115
+ ) ?
116
+ . with_intra_threads (
117
+ utils:: get_config_or (
118
+ model_config,
119
+ "intra_op_parallelism" ,
120
+ & ARGS . intra_op_parallelism [ idx] ,
121
+ )
122
+ . parse ( ) ?,
123
+ ) ?;
124
+ }
125
+ else {
126
+ builder = builder. with_disable_per_session_threads ( ) ?;
127
+ }
128
+ builder = builder
121
129
. with_memory_pattern ( ARGS . onnx_use_memory_pattern == "true" ) ?
122
130
. with_execution_providers ( & OnnxModel :: ep_choices ( ) ) ?;
123
131
match & ARGS . profiling {
0 commit comments