33
44import numpy as np
55
6+ from fastembed .common .onnx_model import OnnxOutputContext
67from fastembed .common .types import NumpyArray
78from fastembed .text .pooled_normalized_embedding import PooledNormalizedEmbedding
89from fastembed .text .onnx_embedding import OnnxTextEmbeddingWorker
@@ -44,9 +45,11 @@ class JinaEmbeddingV3(PooledNormalizedEmbedding):
4445 PASSAGE_TASK = Task .RETRIEVAL_PASSAGE
4546 QUERY_TASK = Task .RETRIEVAL_QUERY
4647
def __init__(self, *args: Any, task_id: Optional[int] = None, **kwargs: Any):
    """Initialise the model and record the task id used by default.

    Args:
        *args: Positional arguments forwarded unchanged to the parent
            initialiser.
        task_id: Task id to use when a later call does not supply one.
            When ``None``, ``self.PASSAGE_TASK`` is used.
        **kwargs: Keyword arguments forwarded unchanged to the parent
            initialiser.
    """
    super().__init__(*args, **kwargs)
    # Compare against None explicitly so that a task id of 0 is honoured
    # rather than being replaced by the passage task.
    self.default_task_id: Union[Task, int] = (
        task_id if task_id is not None else self.PASSAGE_TASK
    )
5053
5154 @classmethod
5255 def _get_worker_class (cls ) -> Type [OnnxTextEmbeddingWorker ]:
@@ -59,27 +62,28 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
def _preprocess_onnx_input(
    self, onnx_input: dict[str, NumpyArray], **kwargs: Any
) -> dict[str, NumpyArray]:
    """Add the ``task_id`` entry to the ONNX input mapping.

    Args:
        onnx_input: Mapping of ONNX input names to arrays; it is modified
            in place and also returned.
        **kwargs: May carry ``task_id``. When it is absent or ``None``,
            ``self.default_task_id`` is used instead.

    Returns:
        The same ``onnx_input`` mapping with ``"task_id"`` set to an
        ``int64`` NumPy array.
    """
    task_id = kwargs.get("task_id")
    if task_id is None:
        # A caller that does not pass task_id should get the configured
        # default rather than a KeyError.
        task_id = self.default_task_id
    onnx_input["task_id"] = np.array(task_id, dtype=np.int64)
    return onnx_input
6467
def embed(
    self,
    documents: Union[str, Iterable[str]],
    batch_size: int = 256,
    parallel: Optional[int] = None,
    task_id: Optional[int] = None,
    **kwargs: Any,
) -> Iterable[NumpyArray]:
    """Embed documents with an explicit or default task id.

    Args:
        documents: A single string or an iterable of strings to embed.
        batch_size: Forwarded to the parent ``embed``.
        parallel: Forwarded to the parent ``embed``.
        task_id: Task id for this call. When ``None``,
            ``self.default_task_id`` is used.
        **kwargs: Extra keyword arguments forwarded to the parent
            ``embed``; ``"task_id"`` is always overwritten.

    Yields:
        Whatever the parent ``embed`` yields for ``documents``.
    """
    # The task id travels in kwargs instead of instance state;
    # required for multiprocessing.
    kwargs["task_id"] = (
        task_id if task_id is not None else self.default_task_id
    )
    yield from super().embed(documents, batch_size, parallel, **kwargs)
7680
def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
    """Embed queries using ``self.QUERY_TASK`` as the task id.

    Args:
        query: A single query string or an iterable of query strings.
        **kwargs: Forwarded to the parent ``embed``; any ``"task_id"``
            supplied by the caller is replaced with ``self.QUERY_TASK``.

    Yields:
        Whatever the parent ``embed`` yields for ``query``.
    """
    kwargs["task_id"] = self.QUERY_TASK
    yield from super().embed(query, **kwargs)
8084
def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
    """Embed passages using ``self.PASSAGE_TASK`` as the task id.

    Args:
        texts: Iterable of passage strings to embed.
        **kwargs: Forwarded to the parent ``embed``; any ``"task_id"``
            supplied by the caller is replaced with ``self.PASSAGE_TASK``.

    Yields:
        Whatever the parent ``embed`` yields for ``texts``.
    """
    kwargs["task_id"] = self.PASSAGE_TASK
    yield from super().embed(texts, **kwargs)
8488
8589
@@ -96,5 +100,9 @@ def init_embedding(
96100 threads = 1 ,
97101 ** kwargs ,
98102 )
99- model .current_task_id = kwargs ["task_id" ]
100103 return model
104+
def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]:
    """Run ONNX embedding on each indexed batch.

    Every batch is embedded with the model's ``default_task_id`` passed
    as ``task_id``.

    Args:
        items: Iterable of ``(index, batch)`` pairs.

    Yields:
        ``(index, onnx_output)`` pairs, in the same order as ``items``,
        where ``onnx_output`` is the result of ``self.model.onnx_embed``.
    """
    for idx, batch in items:
        onnx_output = self.model.onnx_embed(batch, task_id=self.model.default_task_id)
        yield idx, onnx_output
0 commit comments