1- from typing import Any , Dict , Iterable , List , Optional , Sequence , Type , Union
1+ from typing import Any , Iterable , Optional , Sequence , Type , Union
22
33import numpy as np
44
168168 "model_file" : "onnx/model.onnx" ,
169169 },
170170 {
171- "model" : "akshayballal/colpali-v1.2-merged " ,
172- "dim" : 128 ,
173- "description" : "" ,
174- "license" : "mit " ,
175- "size_in_GB" : 6.08 ,
171+ "model" : "jinaai/jina-clip-v1 " ,
172+ "dim" : 768 ,
173+ "description" : "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year " ,
174+ "license" : "apache-2.0 " ,
175+ "size_in_GB" : 0.55 ,
176176 "sources" : {
177- "hf" : "akshayballal/colpali-v1.2-merged-onnx " ,
177+ "hf" : "jinaai/jina-clip-v1 " ,
178178 },
179- "additional_files" : ["model.onnx_data" ],
180- "model_file" : "model.onnx" ,
179+ "model_file" : "onnx/text_model.onnx" ,
181180 },
182181]
183182
@@ -186,12 +185,12 @@ class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[np.ndarray]):
186185 """Implementation of the Flag Embedding model."""
187186
188187 @classmethod
189- def list_supported_models (cls ) -> List [ Dict [str , Any ]]:
188+ def list_supported_models (cls ) -> list [ dict [str , Any ]]:
190189 """
191190 Lists the supported models.
192191
193192 Returns:
194- List[Dict [str, Any]]: A list of dictionaries containing the model information.
193+ list[dict [str, Any]]: A list of dictionaries containing the model information.
195194 """
196195 return supported_onnx_models
197196
@@ -202,7 +201,7 @@ def __init__(
202201 threads : Optional [int ] = None ,
203202 providers : Optional [Sequence [OnnxProvider ]] = None ,
204203 cuda : bool = False ,
205- device_ids : Optional [List [int ]] = None ,
204+ device_ids : Optional [list [int ]] = None ,
206205 lazy_load : bool = False ,
207206 device_id : Optional [int ] = None ,
208207 ** kwargs ,
@@ -218,7 +217,7 @@ def __init__(
218217 Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
219218 cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
220219 Defaults to False.
221- device_ids (Optional[List [int]], optional): The list of device ids to use for data parallel processing in
220+ device_ids (Optional[list [int]], optional): The list of device ids to use for data parallel processing in
222221 workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
223222 lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
224223 Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
@@ -291,16 +290,22 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
291290 return OnnxTextEmbeddingWorker
292291
293292 def _preprocess_onnx_input (
294- self , onnx_input : Dict [str , np .ndarray ], ** kwargs
295- ) -> Dict [str , np .ndarray ]:
293+ self , onnx_input : dict [str , np .ndarray ], ** kwargs
294+ ) -> dict [str , np .ndarray ]:
296295 """
297296 Preprocess the onnx input.
298297 """
299298 return onnx_input
300299
301300 def _post_process_onnx_output (self , output : OnnxOutputContext ) -> Iterable [np .ndarray ]:
302301 embeddings = output .model_output
303- return normalize (embeddings [:, 0 ]).astype (np .float32 )
302+ if embeddings .ndim == 3 : # (batch_size, seq_len, embedding_dim)
303+ processed_embeddings = embeddings [:, 0 ]
304+ elif embeddings .ndim == 2 : # (batch_size, embedding_dim)
305+ processed_embeddings = embeddings
306+ else :
307+ raise ValueError (f"Unsupported embedding shape: { embeddings .shape } " )
308+ return normalize (processed_embeddings ).astype (np .float32 )
304309
305310 def load_onnx_model (self ) -> None :
306311 self ._load_onnx_model (
0 commit comments