1- from typing import Any , Iterable , Optional , Sequence , Type , Union
1+ from typing import Any , Dict , Iterable , List , Optional , Sequence , Type , Union
22
33import numpy as np
44
168168 "model_file" : "onnx/model.onnx" ,
169169 },
170170 {
171- "model" : "jinaai/jina-clip-v1 " ,
172- "dim" : 768 ,
173- "description" : "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year " ,
174- "license" : "apache-2.0 " ,
175- "size_in_GB" : 0.55 ,
171+ "model" : "akshayballal/colpali-v1.2-merged " ,
172+ "dim" : 128 ,
173+ "description" : "" ,
174+ "license" : "mit " ,
175+ "size_in_GB" : 6.08 ,
176176 "sources" : {
177- "hf" : "jinaai/jina-clip-v1 " ,
177+ "hf" : "akshayballal/colpali-v1.2-merged-onnx " ,
178178 },
179- "model_file" : "onnx/text_model.onnx" ,
179+ "additional_files" : ["model.onnx_data" ],
180+ "model_file" : "model.onnx" ,
180181 },
181182]
182183
@@ -185,12 +186,12 @@ class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[np.ndarray]):
185186 """Implementation of the Flag Embedding model."""
186187
@classmethod
def list_supported_models(cls) -> List[Dict[str, Any]]:
    """Return the registry of ONNX text models this class can load.

    Returns:
        List[Dict[str, Any]]: One metadata dictionary per supported model
        (name, dimension, description, license, size, sources, model file).
    """
    return supported_onnx_models
196197
@@ -201,7 +202,7 @@ def __init__(
201202 threads : Optional [int ] = None ,
202203 providers : Optional [Sequence [OnnxProvider ]] = None ,
203204 cuda : bool = False ,
204- device_ids : Optional [list [int ]] = None ,
205+ device_ids : Optional [List [int ]] = None ,
205206 lazy_load : bool = False ,
206207 device_id : Optional [int ] = None ,
207208 ** kwargs ,
@@ -217,7 +218,7 @@ def __init__(
217218 Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
218219 cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
219220 Defaults to False.
220- device_ids (Optional[list [int]], optional): The list of device ids to use for data parallel processing in
221+ device_ids (Optional[List [int]], optional): The list of device ids to use for data parallel processing in
221222 workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
222223 lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
223224 Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
@@ -290,22 +291,16 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
290291 return OnnxTextEmbeddingWorker
291292
292293 def _preprocess_onnx_input (
293- self , onnx_input : dict [str , np .ndarray ], ** kwargs
294- ) -> dict [str , np .ndarray ]:
294+ self , onnx_input : Dict [str , np .ndarray ], ** kwargs
295+ ) -> Dict [str , np .ndarray ]:
295296 """
296297 Preprocess the onnx input.
297298 """
298299 return onnx_input
299300
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
    """Pool and L2-normalize the raw ONNX model output into embeddings.

    Args:
        output (OnnxOutputContext): Raw inference result; ``model_output``
            is expected to be token embeddings of shape
            (batch_size, seq_len, embedding_dim), or already-pooled
            embeddings of shape (batch_size, embedding_dim).

    Returns:
        Iterable[np.ndarray]: Normalized float32 embeddings, one row per
        input document.

    Raises:
        ValueError: If ``model_output`` has an unsupported number of
            dimensions.
    """
    embeddings = output.model_output
    if embeddings.ndim == 3:  # (batch_size, seq_len, embedding_dim): take the CLS token
        pooled = embeddings[:, 0]
    elif embeddings.ndim == 2:  # (batch_size, embedding_dim): already pooled
        # Without this branch, `embeddings[:, 0]` would silently select the
        # first *column* of a pooled output and return garbage embeddings.
        pooled = embeddings
    else:
        raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
    return normalize(pooled).astype(np.float32)
309304
310305 def load_onnx_model (self ) -> None :
311306 self ._load_onnx_model (
0 commit comments