@@ -24,6 +24,8 @@ class OnnxOutputContext:
2424
2525
2626class OnnxModel (Generic [T ]):
27+ EXPOSED_SESSION_OPTIONS = ("enable_cpu_mem_arena" ,)
28+
2729 @classmethod
2830 def _get_worker_class (cls ) -> Type ["EmbeddingWorker[T]" ]:
2931 raise NotImplementedError ("Subclasses must implement this method" )
@@ -60,6 +62,7 @@ def _load_onnx_model(
6062 providers : Optional [Sequence [OnnxProvider ]] = None ,
6163 cuda : bool = False ,
6264 device_id : Optional [int ] = None ,
65+ extra_session_options : Optional [dict [str , Any ]] = None ,
6366 ) -> None :
6467 model_path = model_dir / model_file
6568 # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
@@ -99,6 +102,9 @@ def _load_onnx_model(
99102 so .intra_op_num_threads = threads
100103 so .inter_op_num_threads = threads
101104
105+ if extra_session_options is not None :
106+ self .add_extra_session_options (so , extra_session_options )
107+
102108 self .model = ort .InferenceSession (
103109 str (model_path ), providers = onnx_providers , sess_options = so
104110 )
@@ -113,6 +119,38 @@ def _load_onnx_model(
113119 RuntimeWarning ,
114120 )
115121
122+ @classmethod
123+ def _select_exposed_session_options (cls , model_kwargs : dict [str , Any ]) -> dict [str , Any ]:
124+ """A convenience method to select the exposed session options in models
125+
126+ Args:
127+ model_kwargs (dict[str, Any]): The model kwargs.
128+
129+ Returns:
130+ dict[str, Any]: a dict with filtered exposed session options.
131+ """
132+ return {k : v for k , v in model_kwargs .items () if k in cls .EXPOSED_SESSION_OPTIONS }
133+
@classmethod
def add_extra_session_options(
    cls, session_options: ort.SessionOptions, extra_options: dict[str, Any]
) -> None:
    """Add extra session options to the existing options object in-place.

    Every key of ``extra_options`` is validated before anything is written,
    so a rejected call leaves ``session_options`` untouched.

    Args:
        session_options (ort.SessionOptions): The existing session options object.
        extra_options (dict[str, Any]): The extra session options; every key
            must be listed in ``cls.EXPOSED_SESSION_OPTIONS``.

    Returns:
        None

    Raises:
        AssertionError: If a key of ``extra_options`` is not listed in
            ``cls.EXPOSED_SESSION_OPTIONS``.
    """
    for option in extra_options:
        if option not in cls.EXPOSED_SESSION_OPTIONS:
            # Raised explicitly instead of via an ``assert`` statement so the
            # validation still runs under ``python -O``; the exception type
            # is kept as AssertionError for existing callers.
            raise AssertionError(
                f"{option} is unknown or not exposed (exposed options: {cls.EXPOSED_SESSION_OPTIONS} )"
            )

    if "enable_cpu_mem_arena" in extra_options:
        session_options.enable_cpu_mem_arena = extra_options["enable_cpu_mem_arena"]
153+
def load_onnx_model(self) -> None:
    """Load the ONNX model; this base implementation leaves it to subclasses.

    Raises:
        NotImplementedError: Always, when the method has not been overridden.
    """
    message = "Subclasses must implement this method"
    raise NotImplementedError(message)
118156
0 commit comments