Skip to content

Commit ec0e312

Browse files
authored
new: expose some onnx session options (#578)
* new: expose some onnx session options * fix: fix extra session options is None case * fix: fix missing params * new: add tests
1 parent 44e3329 commit ec0e312

18 files changed

+155
-2
lines changed

fastembed/common/onnx_model.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ class OnnxOutputContext:
2424

2525

2626
class OnnxModel(Generic[T]):
27+
EXPOSED_SESSION_OPTIONS = ("enable_cpu_mem_arena",)
28+
2729
@classmethod
2830
def _get_worker_class(cls) -> Type["EmbeddingWorker[T]"]:
2931
raise NotImplementedError("Subclasses must implement this method")
@@ -60,6 +62,7 @@ def _load_onnx_model(
6062
providers: Optional[Sequence[OnnxProvider]] = None,
6163
cuda: bool = False,
6264
device_id: Optional[int] = None,
65+
extra_session_options: Optional[dict[str, Any]] = None,
6366
) -> None:
6467
model_path = model_dir / model_file
6568
# List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
@@ -99,6 +102,9 @@ def _load_onnx_model(
99102
so.intra_op_num_threads = threads
100103
so.inter_op_num_threads = threads
101104

105+
if extra_session_options is not None:
106+
self.add_extra_session_options(so, extra_session_options)
107+
102108
self.model = ort.InferenceSession(
103109
str(model_path), providers=onnx_providers, sess_options=so
104110
)
@@ -113,6 +119,38 @@ def _load_onnx_model(
113119
RuntimeWarning,
114120
)
115121

122+
@classmethod
123+
def _select_exposed_session_options(cls, model_kwargs: dict[str, Any]) -> dict[str, Any]:
124+
"""A convenience method to select the exposed session options in models
125+
126+
Args:
127+
model_kwargs (dict[str, Any]): The model kwargs.
128+
129+
Returns:
130+
dict[str, Any]: a dict with filtered exposed session options.
131+
"""
132+
return {k: v for k, v in model_kwargs.items() if k in cls.EXPOSED_SESSION_OPTIONS}
133+
134+
@classmethod
135+
def add_extra_session_options(
136+
cls, session_options: ort.SessionOptions, extra_options: dict[str, Any]
137+
) -> None:
138+
"""Add extra session options to the existing options object in-place
139+
140+
Args:
141+
session_options (ort.SessionOptions): The existing session options object.
142+
extra_options (dict[str, Any]): The extra session options available in cls.EXPOSED_SESSION_OPTIONS.
143+
144+
Returns:
145+
None
146+
"""
147+
for option in extra_options:
148+
assert (
149+
option in cls.EXPOSED_SESSION_OPTIONS
150+
), f"{option} is unknown or not exposed (exposed options: {cls.EXPOSED_SESSION_OPTIONS})"
151+
if "enable_cpu_mem_arena" in extra_options:
152+
session_options.enable_cpu_mem_arena = extra_options["enable_cpu_mem_arena"]
153+
116154
def load_onnx_model(self) -> None:
117155
raise NotImplementedError("Subclasses must implement this method")
118156

fastembed/image/onnx_embedding.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(
9898
super().__init__(model_name, cache_dir, threads, **kwargs)
9999
self.providers = providers
100100
self.lazy_load = lazy_load
101+
self._extra_session_options = self._select_exposed_session_options(kwargs)
101102

102103
# List of device ids, that can be used for data parallel processing in workers
103104
self.device_ids = device_ids
@@ -134,6 +135,7 @@ def load_onnx_model(self) -> None:
134135
providers=self.providers,
135136
cuda=self.cuda,
136137
device_id=self.device_id,
138+
extra_session_options=self._extra_session_options,
137139
)
138140

139141
@classmethod
@@ -180,6 +182,7 @@ def embed(
180182
device_ids=self.device_ids,
181183
local_files_only=self._local_files_only,
182184
specific_model_path=self._specific_model_path,
185+
extra_session_options=self._extra_session_options,
183186
**kwargs,
184187
)
185188

fastembed/image/onnx_image_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def _load_onnx_model(
5555
providers: Optional[Sequence[OnnxProvider]] = None,
5656
cuda: bool = False,
5757
device_id: Optional[int] = None,
58+
extra_session_options: Optional[dict[str, Any]] = None,
5859
) -> None:
5960
super()._load_onnx_model(
6061
model_dir=model_dir,
@@ -63,6 +64,7 @@ def _load_onnx_model(
6364
providers=providers,
6465
cuda=cuda,
6566
device_id=device_id,
67+
extra_session_options=extra_session_options,
6668
)
6769
self.processor = load_preprocessor(model_dir=model_dir)
6870

@@ -99,6 +101,7 @@ def _embed_images(
99101
device_ids: Optional[list[int]] = None,
100102
local_files_only: bool = False,
101103
specific_model_path: Optional[str] = None,
104+
extra_session_options: Optional[dict[str, Any]] = None,
102105
**kwargs: Any,
103106
) -> Iterable[T]:
104107
is_small = False
@@ -130,6 +133,9 @@ def _embed_images(
130133
**kwargs,
131134
}
132135

136+
if extra_session_options is not None:
137+
params.update(extra_session_options)
138+
133139
pool = ParallelWorkerPool(
134140
num_workers=parallel or 1,
135141
worker=self._get_worker_class(),

fastembed/late_interaction/colbert.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def __init__(
143143
super().__init__(model_name, cache_dir, threads, **kwargs)
144144
self.providers = providers
145145
self.lazy_load = lazy_load
146+
self._extra_session_options = self._select_exposed_session_options(kwargs)
146147

147148
# List of device ids, that can be used for data parallel processing in workers
148149
self.device_ids = device_ids
@@ -182,6 +183,7 @@ def load_onnx_model(self) -> None:
182183
providers=self.providers,
183184
cuda=self.cuda,
184185
device_id=self.device_id,
186+
extra_session_options=self._extra_session_options,
185187
)
186188
self.query_tokenizer, _ = load_tokenizer(model_dir=self._model_dir)
187189

@@ -235,6 +237,7 @@ def embed(
235237
device_ids=self.device_ids,
236238
local_files_only=self._local_files_only,
237239
specific_model_path=self._specific_model_path,
240+
extra_session_options=self._extra_session_options,
238241
**kwargs,
239242
)
240243

fastembed/late_interaction_multimodal/colpali.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __init__(
8080
super().__init__(model_name, cache_dir, threads, **kwargs)
8181
self.providers = providers
8282
self.lazy_load = lazy_load
83+
self._extra_session_options = self._select_exposed_session_options(kwargs)
8384

8485
# List of device ids, that can be used for data parallel processing in workers
8586
self.device_ids = device_ids
@@ -125,6 +126,7 @@ def load_onnx_model(self) -> None:
125126
providers=self.providers,
126127
cuda=self.cuda,
127128
device_id=self.device_id,
129+
extra_session_options=self._extra_session_options,
128130
)
129131

130132
def _post_process_onnx_image_output(
@@ -238,6 +240,7 @@ def embed_text(
238240
device_ids=self.device_ids,
239241
local_files_only=self._local_files_only,
240242
specific_model_path=self._specific_model_path,
243+
extra_session_options=self._extra_session_options,
241244
**kwargs,
242245
)
243246

@@ -273,6 +276,7 @@ def embed_image(
273276
device_ids=self.device_ids,
274277
local_files_only=self._local_files_only,
275278
specific_model_path=self._specific_model_path,
279+
extra_session_options=self._extra_session_options,
276280
**kwargs,
277281
)
278282

fastembed/late_interaction_multimodal/onnx_multimodal_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def _load_onnx_model(
6464
providers: Optional[Sequence[OnnxProvider]] = None,
6565
cuda: bool = False,
6666
device_id: Optional[int] = None,
67+
extra_session_options: Optional[dict[str, Any]] = None,
6768
) -> None:
6869
super()._load_onnx_model(
6970
model_dir=model_dir,
@@ -72,6 +73,7 @@ def _load_onnx_model(
7273
providers=providers,
7374
cuda=cuda,
7475
device_id=device_id,
76+
extra_session_options=extra_session_options,
7577
)
7678
self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir)
7779
assert self.tokenizer is not None
@@ -122,6 +124,7 @@ def _embed_documents(
122124
device_ids: Optional[list[int]] = None,
123125
local_files_only: bool = False,
124126
specific_model_path: Optional[str] = None,
127+
extra_session_options: Optional[dict[str, Any]] = None,
125128
**kwargs: Any,
126129
) -> Iterable[T]:
127130
is_small = False
@@ -153,6 +156,9 @@ def _embed_documents(
153156
**kwargs,
154157
}
155158

159+
if extra_session_options is not None:
160+
params.update(extra_session_options)
161+
156162
pool = ParallelWorkerPool(
157163
num_workers=parallel or 1,
158164
worker=self._get_text_worker_class(),
@@ -189,6 +195,7 @@ def _embed_images(
189195
device_ids: Optional[list[int]] = None,
190196
local_files_only: bool = False,
191197
specific_model_path: Optional[str] = None,
198+
extra_session_options: Optional[dict[str, Any]] = None,
192199
**kwargs: Any,
193200
) -> Iterable[T]:
194201
is_small = False
@@ -220,6 +227,9 @@ def _embed_images(
220227
**kwargs,
221228
}
222229

230+
if extra_session_options is not None:
231+
params.update(extra_session_options)
232+
223233
pool = ParallelWorkerPool(
224234
num_workers=parallel or 1,
225235
worker=self._get_image_worker_class(),

fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(
111111
super().__init__(model_name, cache_dir, threads, **kwargs)
112112
self.providers = providers
113113
self.lazy_load = lazy_load
114+
self._extra_session_options = self._select_exposed_session_options(kwargs)
114115

115116
# List of device ids, that can be used for data parallel processing in workers
116117
self.device_ids = device_ids
@@ -150,6 +151,7 @@ def load_onnx_model(self) -> None:
150151
providers=self.providers,
151152
cuda=self.cuda,
152153
device_id=self.device_id,
154+
extra_session_options=self._extra_session_options,
153155
)
154156

155157
def rerank(
@@ -192,6 +194,7 @@ def rerank_pairs(
192194
device_ids=self.device_ids,
193195
local_files_only=self._local_files_only,
194196
specific_model_path=self._specific_model_path,
197+
extra_session_options=self._extra_session_options,
195198
**kwargs,
196199
)
197200

fastembed/rerank/cross_encoder/onnx_text_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def _load_onnx_model(
3333
providers: Optional[Sequence[OnnxProvider]] = None,
3434
cuda: bool = False,
3535
device_id: Optional[int] = None,
36+
extra_session_options: Optional[dict[str, Any]] = None,
3637
) -> None:
3738
super()._load_onnx_model(
3839
model_dir=model_dir,
@@ -41,6 +42,7 @@ def _load_onnx_model(
4142
providers=providers,
4243
cuda=cuda,
4344
device_id=device_id,
45+
extra_session_options=extra_session_options,
4446
)
4547
self.tokenizer, _ = load_tokenizer(model_dir=model_dir)
4648
assert self.tokenizer is not None
@@ -96,6 +98,7 @@ def _rerank_pairs(
9698
device_ids: Optional[list[int]] = None,
9799
local_files_only: bool = False,
98100
specific_model_path: Optional[str] = None,
101+
extra_session_options: Optional[dict[str, Any]] = None,
99102
**kwargs: Any,
100103
) -> Iterable[float]:
101104
is_small = False
@@ -127,6 +130,9 @@ def _rerank_pairs(
127130
**kwargs,
128131
}
129132

133+
if extra_session_options is not None:
134+
params.update(extra_session_options)
135+
130136
pool = ParallelWorkerPool(
131137
num_workers=parallel or 1,
132138
worker=self._get_worker_class(),

fastembed/sparse/bm42.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
super().__init__(model_name, cache_dir, threads, **kwargs)
104104
self.providers = providers
105105
self.lazy_load = lazy_load
106+
self._extra_session_options = self._select_exposed_session_options(kwargs)
106107

107108
# List of device ids, that can be used for data parallel processing in workers
108109
self.device_ids = device_ids
@@ -146,6 +147,7 @@ def load_onnx_model(self) -> None:
146147
providers=self.providers,
147148
cuda=self.cuda,
148149
device_id=self.device_id,
150+
extra_session_options=self._extra_session_options,
149151
)
150152

151153
for token, idx in self.tokenizer.get_vocab().items(): # type: ignore[union-attr]
@@ -312,6 +314,7 @@ def embed(
312314
alpha=self.alpha,
313315
local_files_only=self._local_files_only,
314316
specific_model_path=self._specific_model_path,
317+
extra_session_options=self._extra_session_options,
315318
)
316319

317320
@classmethod

fastembed/sparse/minicoil.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ def __init__(
117117
self.device_ids = device_ids
118118
self.cuda = cuda
119119
self.device_id = device_id
120+
self._extra_session_options = self._select_exposed_session_options(kwargs)
121+
120122
self.k = k
121123
self.b = b
122124
self.avg_len = avg_len
@@ -153,6 +155,7 @@ def load_onnx_model(self) -> None:
153155
providers=self.providers,
154156
cuda=self.cuda,
155157
device_id=self.device_id,
158+
extra_session_options=self._extra_session_options,
156159
)
157160

158161
assert self.tokenizer is not None
@@ -221,6 +224,7 @@ def embed(
221224
is_query=False,
222225
local_files_only=self._local_files_only,
223226
specific_model_path=self._specific_model_path,
227+
extra_session_options=self._extra_session_options,
224228
**kwargs,
225229
)
226230

0 commit comments

Comments
 (0)