Skip to content

Commit a260022

Browse files
authored
fix: propagate local files only and specific model path into embed parallel (#524)
1 parent d5da562 commit a260022

File tree

14 files changed

+67
-10
lines changed

14 files changed

+67
-10
lines changed

fastembed/image/onnx_embedding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,12 @@ def __init__(
112112

113113
self.model_description = self._get_model_description(model_name)
114114
self.cache_dir = str(define_cache_dir(cache_dir))
115+
self._specific_model_path = specific_model_path
115116
self._model_dir = self.download_model(
116117
self.model_description,
117118
self.cache_dir,
118119
local_files_only=self._local_files_only,
119-
specific_model_path=specific_model_path,
120+
specific_model_path=self._specific_model_path,
120121
)
121122

122123
if not self.lazy_load:
@@ -177,6 +178,8 @@ def embed(
177178
providers=self.providers,
178179
cuda=self.cuda,
179180
device_ids=self.device_ids,
181+
local_files_only=self._local_files_only,
182+
specific_model_path=self._specific_model_path,
180183
**kwargs,
181184
)
182185

fastembed/image/onnx_image_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ def _embed_images(
9797
providers: Optional[Sequence[OnnxProvider]] = None,
9898
cuda: bool = False,
9999
device_ids: Optional[list[int]] = None,
100+
local_files_only: bool = False,
101+
specific_model_path: Optional[str] = None,
100102
**kwargs: Any,
101103
) -> Iterable[T]:
102104
is_small = False
@@ -123,6 +125,8 @@ def _embed_images(
123125
"model_name": model_name,
124126
"cache_dir": cache_dir,
125127
"providers": providers,
128+
"local_files_only": local_files_only,
129+
"specific_model_path": specific_model_path,
126130
**kwargs,
127131
}
128132

fastembed/late_interaction/colbert.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,12 @@ def __init__(
169169
self.model_description = self._get_model_description(model_name)
170170
self.cache_dir = str(define_cache_dir(cache_dir))
171171

172+
self._specific_model_path = specific_model_path
172173
self._model_dir = self.download_model(
173174
self.model_description,
174175
self.cache_dir,
175176
local_files_only=self._local_files_only,
176-
specific_model_path=specific_model_path,
177+
specific_model_path=self._specific_model_path,
177178
)
178179
self.mask_token_id: Optional[int] = None
179180
self.pad_token_id: Optional[int] = None
@@ -233,6 +234,8 @@ def embed(
233234
providers=self.providers,
234235
cuda=self.cuda,
235236
device_ids=self.device_ids,
237+
local_files_only=self._local_files_only,
238+
specific_model_path=self._specific_model_path,
236239
**kwargs,
237240
)
238241

fastembed/late_interaction/token_embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
)
1010
from fastembed.text.onnx_embedding import OnnxTextEmbedding
1111
from fastembed.text.onnx_text_model import TextEmbeddingWorker
12-
import numpy as np
12+
1313

1414
supported_token_embeddings_models = [
1515
DenseModelDescription(

fastembed/late_interaction_multimodal/colpali.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,12 @@ def __init__(
9595
self.model_description = self._get_model_description(model_name)
9696
self.cache_dir = str(define_cache_dir(cache_dir))
9797

98+
self._specific_model_path = specific_model_path
9899
self._model_dir = self.download_model(
99100
self.model_description,
100101
self.cache_dir,
101102
local_files_only=self._local_files_only,
102-
specific_model_path=specific_model_path,
103+
specific_model_path=self._specific_model_path,
103104
)
104105
self.mask_token_id = None
105106
self.pad_token_id = None
@@ -235,6 +236,8 @@ def embed_text(
235236
providers=self.providers,
236237
cuda=self.cuda,
237238
device_ids=self.device_ids,
239+
local_files_only=self._local_files_only,
240+
specific_model_path=self._specific_model_path,
238241
**kwargs,
239242
)
240243

@@ -268,6 +271,8 @@ def embed_image(
268271
providers=self.providers,
269272
cuda=self.cuda,
270273
device_ids=self.device_ids,
274+
local_files_only=self._local_files_only,
275+
specific_model_path=self._specific_model_path,
271276
**kwargs,
272277
)
273278

fastembed/late_interaction_multimodal/onnx_multimodal_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ def _embed_documents(
120120
providers: Optional[Sequence[OnnxProvider]] = None,
121121
cuda: bool = False,
122122
device_ids: Optional[list[int]] = None,
123+
local_files_only: bool = False,
124+
specific_model_path: Optional[str] = None,
123125
**kwargs: Any,
124126
) -> Iterable[T]:
125127
is_small = False
@@ -146,6 +148,8 @@ def _embed_documents(
146148
"model_name": model_name,
147149
"cache_dir": cache_dir,
148150
"providers": providers,
151+
"local_files_only": local_files_only,
152+
"specific_model_path": specific_model_path,
149153
**kwargs,
150154
}
151155

@@ -183,6 +187,8 @@ def _embed_images(
183187
providers: Optional[Sequence[OnnxProvider]] = None,
184188
cuda: bool = False,
185189
device_ids: Optional[list[int]] = None,
190+
local_files_only: bool = False,
191+
specific_model_path: Optional[str] = None,
186192
**kwargs: Any,
187193
) -> Iterable[T]:
188194
is_small = False
@@ -209,6 +215,8 @@ def _embed_images(
209215
"model_name": model_name,
210216
"cache_dir": cache_dir,
211217
"providers": providers,
218+
"local_files_only": local_files_only,
219+
"specific_model_path": specific_model_path,
212220
**kwargs,
213221
}
214222

fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,12 @@ def __init__(
131131

132132
self.model_description = self._get_model_description(model_name)
133133
self.cache_dir = str(define_cache_dir(cache_dir))
134+
self._specific_model_path = specific_model_path
134135
self._model_dir = self.download_model(
135136
self.model_description,
136137
self.cache_dir,
137138
local_files_only=self._local_files_only,
138-
specific_model_path=specific_model_path,
139+
specific_model_path=self._specific_model_path,
139140
)
140141

141142
if not self.lazy_load:
@@ -189,6 +190,8 @@ def rerank_pairs(
189190
providers=self.providers,
190191
cuda=self.cuda,
191192
device_ids=self.device_ids,
193+
local_files_only=self._local_files_only,
194+
specific_model_path=self._specific_model_path,
192195
**kwargs,
193196
)
194197

fastembed/rerank/cross_encoder/onnx_text_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def _rerank_pairs(
9494
providers: Optional[Sequence[OnnxProvider]] = None,
9595
cuda: bool = False,
9696
device_ids: Optional[list[int]] = None,
97+
local_files_only: bool = False,
98+
specific_model_path: Optional[str] = None,
9799
**kwargs: Any,
98100
) -> Iterable[float]:
99101
is_small = False
@@ -120,6 +122,8 @@ def _rerank_pairs(
120122
"model_name": model_name,
121123
"cache_dir": cache_dir,
122124
"providers": providers,
125+
"local_files_only": local_files_only,
126+
"specific_model_path": specific_model_path,
123127
**kwargs,
124128
}
125129

fastembed/sparse/bm25.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,12 @@ def __init__(
115115
model_description = self._get_model_description(model_name)
116116
self.cache_dir = str(define_cache_dir(cache_dir))
117117

118+
self._specific_model_path = specific_model_path
118119
self._model_dir = self.download_model(
119120
model_description,
120121
self.cache_dir,
121122
local_files_only=self._local_files_only,
122-
specific_model_path=specific_model_path,
123+
specific_model_path=self._specific_model_path,
123124
)
124125

125126
self.token_max_length = token_max_length
@@ -160,6 +161,8 @@ def _embed_documents(
160161
documents: Union[str, Iterable[str]],
161162
batch_size: int = 256,
162163
parallel: Optional[int] = None,
164+
local_files_only: bool = False,
165+
specific_model_path: Optional[str] = None,
163166
) -> Iterable[SparseEmbedding]:
164167
is_small = False
165168

@@ -188,6 +191,8 @@ def _embed_documents(
188191
"language": self.language,
189192
"token_max_length": self.token_max_length,
190193
"disable_stemmer": self.disable_stemmer,
194+
"local_files_only": local_files_only,
195+
"specific_model_path": specific_model_path,
191196
}
192197
pool = ParallelWorkerPool(
193198
num_workers=parallel or 1,
@@ -226,6 +231,8 @@ def embed(
226231
documents=documents,
227232
batch_size=batch_size,
228233
parallel=parallel,
234+
local_files_only=self._local_files_only,
235+
specific_model_path=self._specific_model_path,
229236
)
230237

231238
def _stem(self, tokens: list[str]) -> list[str]:

fastembed/sparse/bm42.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,12 @@ def __init__(
110110
self.model_description = self._get_model_description(model_name)
111111
self.cache_dir = str(define_cache_dir(cache_dir))
112112

113+
self._specific_model_path = specific_model_path
113114
self._model_dir = self.download_model(
114115
self.model_description,
115116
self.cache_dir,
116117
local_files_only=self._local_files_only,
117-
specific_model_path=specific_model_path,
118+
specific_model_path=self._specific_model_path,
118119
)
119120

120121
self.invert_vocab: dict[int, str] = {}
@@ -301,6 +302,8 @@ def embed(
301302
cuda=self.cuda,
302303
device_ids=self.device_ids,
303304
alpha=self.alpha,
305+
local_files_only=self._local_files_only,
306+
specific_model_path=self._specific_model_path,
304307
)
305308

306309
@classmethod

0 commit comments

Comments
 (0)