Skip to content

Commit c8eabdb

Browse files
committed
new: improve task setter in jina v3
1 parent aa0c475 commit c8eabdb

File tree

3 files changed

+18
-31
lines changed

3 files changed

+18
-31
lines changed

fastembed/text/multitask_embedding.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55

6+
from fastembed.common.onnx_model import OnnxOutputContext
67
from fastembed.common.types import NumpyArray
78
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
89
from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker
@@ -44,9 +45,11 @@ class JinaEmbeddingV3(PooledNormalizedEmbedding):
4445
PASSAGE_TASK = Task.RETRIEVAL_PASSAGE
4546
QUERY_TASK = Task.RETRIEVAL_QUERY
4647

47-
def __init__(self, *args: Any, **kwargs: Any):
48+
def __init__(self, *args: Any, task_id: Optional[int] = None, **kwargs: Any):
4849
super().__init__(*args, **kwargs)
49-
self.current_task_id: Union[Task, int] = self.PASSAGE_TASK
50+
self.default_task_id: Union[Task, int] = (
51+
task_id if task_id is not None else self.PASSAGE_TASK
52+
)
5053

5154
@classmethod
5255
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
@@ -59,27 +62,28 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
5962
def _preprocess_onnx_input(
6063
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
6164
) -> dict[str, NumpyArray]:
62-
onnx_input["task_id"] = np.array(self.current_task_id, dtype=np.int64)
65+
onnx_input["task_id"] = np.array(kwargs["task_id"], dtype=np.int64)
6366
return onnx_input
6467

6568
def embed(
6669
self,
6770
documents: Union[str, Iterable[str]],
6871
batch_size: int = 256,
6972
parallel: Optional[int] = None,
70-
task_id: int = PASSAGE_TASK,
73+
task_id: Optional[int] = None,
7174
**kwargs: Any,
7275
) -> Iterable[NumpyArray]:
73-
self.current_task_id = task_id
74-
kwargs["task_id"] = task_id
76+
kwargs["task_id"] = (
77+
task_id if task_id is not None else self.default_task_id
78+
) # required for multiprocessing
7579
yield from super().embed(documents, batch_size, parallel, **kwargs)
7680

7781
def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
78-
self.current_task_id = self.QUERY_TASK
82+
kwargs["task_id"] = self.QUERY_TASK
7983
yield from super().embed(query, **kwargs)
8084

8185
def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
82-
self.current_task_id = self.PASSAGE_TASK
86+
kwargs["task_id"] = self.PASSAGE_TASK
8387
yield from super().embed(texts, **kwargs)
8488

8589

@@ -96,5 +100,9 @@ def init_embedding(
96100
threads=1,
97101
**kwargs,
98102
)
99-
model.current_task_id = kwargs["task_id"]
100103
return model
104+
105+
def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]:
106+
for idx, batch in items:
107+
onnx_output = self.model.onnx_embed(batch, task_id=self.model.default_task_id)
108+
yield idx, onnx_output

fastembed/text/onnx_text_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _embed_documents(
115115
if not hasattr(self, "model") or self.model is None:
116116
self.load_onnx_model()
117117
for batch in iter_batch(documents, batch_size):
118-
yield from self._post_process_onnx_output(self.onnx_embed(batch))
118+
yield from self._post_process_onnx_output(self.onnx_embed(batch, **kwargs))
119119
else:
120120
if parallel == 0:
121121
parallel = os.cpu_count()

tests/test_text_multitask_embeddings.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -207,27 +207,6 @@ def test_parallel_processing(dim: int, model_name: str):
207207
delete_model_cache(model.model._model_dir)
208208

209209

210-
def test_task_assignment():
211-
is_ci = os.getenv("CI")
212-
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
213-
214-
if is_ci and not is_manual:
215-
pytest.skip("Skipping in CI non-manual mode")
216-
217-
for model_desc in JinaEmbeddingV3._list_supported_models():
218-
# todo: once we add more models, we should not test models >1GB size locally
219-
model_name = model_desc.model
220-
221-
model = TextEmbedding(model_name=model_name)
222-
223-
for i, task_id in enumerate(Task):
224-
_ = list(model.embed(documents=docs, batch_size=1, task_id=i))
225-
assert model.model.current_task_id == task_id
226-
227-
if is_ci:
228-
delete_model_cache(model.model._model_dir)
229-
230-
231210
@pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"])
232211
def test_lazy_load(model_name: str):
233212
is_ci = os.getenv("CI")

0 commit comments

Comments (0)