Skip to content

Commit 9631970

Browse files
committed
Bump version
1 parent c0f6a5b commit 9631970

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

fastembed/text/onnx_embedding.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Iterable, List, Optional, Sequence, Type, Union
1+
from typing import Any, Iterable, Optional, Sequence, Type, Union
22

33
import numpy as np
44

@@ -168,16 +168,15 @@
168168
"model_file": "onnx/model.onnx",
169169
},
170170
{
171-
"model": "akshayballal/colpali-v1.2-merged",
172-
"dim": 128,
173-
"description": "",
174-
"license": "mit",
175-
"size_in_GB": 6.08,
171+
"model": "jinaai/jina-clip-v1",
172+
"dim": 768,
173+
"description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
174+
"license": "apache-2.0",
175+
"size_in_GB": 0.55,
176176
"sources": {
177-
"hf": "akshayballal/colpali-v1.2-merged-onnx",
177+
"hf": "jinaai/jina-clip-v1",
178178
},
179-
"additional_files": ["model.onnx_data"],
180-
"model_file": "model.onnx",
179+
"model_file": "onnx/text_model.onnx",
181180
},
182181
]
183182

@@ -186,12 +185,12 @@ class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[np.ndarray]):
186185
"""Implementation of the Flag Embedding model."""
187186

188187
@classmethod
189-
def list_supported_models(cls) -> List[Dict[str, Any]]:
188+
def list_supported_models(cls) -> list[dict[str, Any]]:
190189
"""
191190
Lists the supported models.
192191
193192
Returns:
194-
List[Dict[str, Any]]: A list of dictionaries containing the model information.
193+
list[dict[str, Any]]: A list of dictionaries containing the model information.
195194
"""
196195
return supported_onnx_models
197196

@@ -202,7 +201,7 @@ def __init__(
202201
threads: Optional[int] = None,
203202
providers: Optional[Sequence[OnnxProvider]] = None,
204203
cuda: bool = False,
205-
device_ids: Optional[List[int]] = None,
204+
device_ids: Optional[list[int]] = None,
206205
lazy_load: bool = False,
207206
device_id: Optional[int] = None,
208207
**kwargs,
@@ -218,7 +217,7 @@ def __init__(
218217
Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
219218
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
220219
Defaults to False.
221-
device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
220+
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
222221
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
223222
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
224223
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
@@ -291,16 +290,22 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
291290
return OnnxTextEmbeddingWorker
292291

293292
def _preprocess_onnx_input(
294-
self, onnx_input: Dict[str, np.ndarray], **kwargs
295-
) -> Dict[str, np.ndarray]:
293+
self, onnx_input: dict[str, np.ndarray], **kwargs
294+
) -> dict[str, np.ndarray]:
296295
"""
297296
Preprocess the onnx input.
298297
"""
299298
return onnx_input
300299

301300
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
302301
embeddings = output.model_output
303-
return normalize(embeddings[:, 0]).astype(np.float32)
302+
if embeddings.ndim == 3: # (batch_size, seq_len, embedding_dim)
303+
processed_embeddings = embeddings[:, 0]
304+
elif embeddings.ndim == 2: # (batch_size, embedding_dim)
305+
processed_embeddings = embeddings
306+
else:
307+
raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
308+
return normalize(processed_embeddings).astype(np.float32)
304309

305310
def load_onnx_model(self) -> None:
306311
self._load_onnx_model(

0 commit comments

Comments (0)