Skip to content

Commit a5a5cea

Browse files
committed
HF sources for all models
1 parent e89654d commit a5a5cea

File tree

1 file changed

+16
-21
lines changed

1 file changed

+16
-21
lines changed

fastembed/text/onnx_embedding.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Iterable, Optional, Sequence, Type, Union
1+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Type, Union
22

33
import numpy as np
44

@@ -168,15 +168,16 @@
168168
"model_file": "onnx/model.onnx",
169169
},
170170
{
171-
"model": "jinaai/jina-clip-v1",
172-
"dim": 768,
173-
"description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
174-
"license": "apache-2.0",
175-
"size_in_GB": 0.55,
171+
"model": "akshayballal/colpali-v1.2-merged",
172+
"dim": 128,
173+
"description": "",
174+
"license": "mit",
175+
"size_in_GB": 6.08,
176176
"sources": {
177-
"hf": "jinaai/jina-clip-v1",
177+
"hf": "akshayballal/colpali-v1.2-merged-onnx",
178178
},
179-
"model_file": "onnx/text_model.onnx",
179+
"additional_files": ["model.onnx_data"],
180+
"model_file": "model.onnx",
180181
},
181182
]
182183

@@ -185,12 +186,12 @@ class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[np.ndarray]):
185186
"""Implementation of the Flag Embedding model."""
186187

187188
@classmethod
188-
def list_supported_models(cls) -> list[dict[str, Any]]:
189+
def list_supported_models(cls) -> List[Dict[str, Any]]:
189190
"""
190191
Lists the supported models.
191192
192193
Returns:
193-
list[dict[str, Any]]: A list of dictionaries containing the model information.
194+
List[Dict[str, Any]]: A list of dictionaries containing the model information.
194195
"""
195196
return supported_onnx_models
196197

@@ -201,7 +202,7 @@ def __init__(
201202
threads: Optional[int] = None,
202203
providers: Optional[Sequence[OnnxProvider]] = None,
203204
cuda: bool = False,
204-
device_ids: Optional[list[int]] = None,
205+
device_ids: Optional[List[int]] = None,
205206
lazy_load: bool = False,
206207
device_id: Optional[int] = None,
207208
**kwargs,
@@ -217,7 +218,7 @@ def __init__(
217218
Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
218219
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
219220
Defaults to False.
220-
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
221+
device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
221222
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
222223
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
223224
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
@@ -290,22 +291,16 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
290291
return OnnxTextEmbeddingWorker
291292

292293
def _preprocess_onnx_input(
293-
self, onnx_input: dict[str, np.ndarray], **kwargs
294-
) -> dict[str, np.ndarray]:
294+
self, onnx_input: Dict[str, np.ndarray], **kwargs
295+
) -> Dict[str, np.ndarray]:
295296
"""
296297
Preprocess the onnx input.
297298
"""
298299
return onnx_input
299300

300301
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
301302
embeddings = output.model_output
302-
if embeddings.ndim == 3: # (batch_size, seq_len, embedding_dim)
303-
processed_embeddings = embeddings[:, 0]
304-
elif embeddings.ndim == 2: # (batch_size, embedding_dim)
305-
processed_embeddings = embeddings
306-
else:
307-
raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
308-
return normalize(processed_embeddings).astype(np.float32)
303+
return normalize(embeddings[:, 0]).astype(np.float32)
309304

310305
def load_onnx_model(self) -> None:
311306
self._load_onnx_model(

0 commit comments

Comments (0)