@@ -60,9 +60,14 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
6060 return supported_multitask_models
6161
6262 def _preprocess_onnx_input (
63- self , onnx_input : dict [str , NumpyArray ], ** kwargs : Any
63+ self ,
64+ onnx_input : dict [str , NumpyArray ],
65+ task_id : Optional [Union [int , Task ]] = None ,
66+ ** kwargs : Any ,
6467 ) -> dict [str , NumpyArray ]:
65- onnx_input ["task_id" ] = np .array (kwargs ["task_id" ], dtype = np .int64 )
68+ if task_id is None :
69+ raise ValueError (f"task_id must be provided for JinaEmbeddingV3, got <{ task_id } >" )
70+ onnx_input ["task_id" ] = np .array (task_id , dtype = np .int64 )
6671 return onnx_input
6772
6873 def embed (
@@ -73,18 +78,16 @@ def embed(
7378 task_id : Optional [int ] = None ,
7479 ** kwargs : Any ,
7580 ) -> Iterable [NumpyArray ]:
76- kwargs [ " task_id" ] = (
81+ task_id = (
7782 task_id if task_id is not None else self .default_task_id
7883 ) # required for multiprocessing
79- yield from super ().embed (documents , batch_size , parallel , ** kwargs )
84+ yield from super ().embed (documents , batch_size , parallel , task_id = task_id , ** kwargs )
8085
8186 def query_embed (self , query : Union [str , Iterable [str ]], ** kwargs : Any ) -> Iterable [NumpyArray ]:
82- kwargs ["task_id" ] = self .QUERY_TASK
83- yield from super ().embed (query , ** kwargs )
87+ yield from super ().embed (query , task_id = self .QUERY_TASK , ** kwargs )
8488
8589 def passage_embed (self , texts : Iterable [str ], ** kwargs : Any ) -> Iterable [NumpyArray ]:
86- kwargs ["task_id" ] = self .PASSAGE_TASK
87- yield from super ().embed (texts , ** kwargs )
90+ yield from super ().embed (texts , task_id = self .PASSAGE_TASK , ** kwargs )
8891
8992
9093class JinaEmbeddingV3Worker (OnnxTextEmbeddingWorker ):
@@ -94,15 +97,15 @@ def init_embedding(
9497 cache_dir : str ,
9598 ** kwargs : Any ,
9699 ) -> JinaEmbeddingV3 :
97- model = JinaEmbeddingV3 (
100+ return JinaEmbeddingV3 (
98101 model_name = model_name ,
99102 cache_dir = cache_dir ,
100103 threads = 1 ,
101104 ** kwargs ,
102105 )
103- return model
104106
105107 def process (self , items : Iterable [tuple [int , Any ]]) -> Iterable [tuple [int , OnnxOutputContext ]]:
108+ self .model : JinaEmbeddingV3 # mypy complaints `self.model` does not have `default_task_id`
106109 for idx , batch in items :
107110 onnx_output = self .model .onnx_embed (batch , task_id = self .model .default_task_id )
108111 yield idx , onnx_output
0 commit comments