Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ jobs:

- name: Run pytest
run: |
poetry run pytest
HF_TOKEN=${{ secrets.HF_TOKEN }} poetry run pytest
41 changes: 26 additions & 15 deletions fastembed/text/multitask_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np

from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.types import NumpyArray
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker
Expand Down Expand Up @@ -44,9 +45,11 @@ class JinaEmbeddingV3(PooledNormalizedEmbedding):
PASSAGE_TASK = Task.RETRIEVAL_PASSAGE
QUERY_TASK = Task.RETRIEVAL_QUERY

def __init__(self, *args: Any, **kwargs: Any):
def __init__(self, *args: Any, task_id: Optional[int] = None, **kwargs: Any):
super().__init__(*args, **kwargs)
self.current_task_id: Union[Task, int] = self.PASSAGE_TASK
self.default_task_id: Union[Task, int] = (
task_id if task_id is not None else self.PASSAGE_TASK
)

@classmethod
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
Expand All @@ -57,30 +60,34 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
return supported_multitask_models

def _preprocess_onnx_input(
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
self,
onnx_input: dict[str, NumpyArray],
task_id: Optional[Union[int, Task]] = None,
**kwargs: Any,
) -> dict[str, NumpyArray]:
onnx_input["task_id"] = np.array(self.current_task_id, dtype=np.int64)
if task_id is None:
raise ValueError(f"task_id must be provided for JinaEmbeddingV3, got <{task_id}>")
onnx_input["task_id"] = np.array(task_id, dtype=np.int64)
return onnx_input

def embed(
self,
documents: Union[str, Iterable[str]],
batch_size: int = 256,
parallel: Optional[int] = None,
task_id: int = PASSAGE_TASK,
task_id: Optional[int] = None,
**kwargs: Any,
) -> Iterable[NumpyArray]:
self.current_task_id = task_id
kwargs["task_id"] = task_id
yield from super().embed(documents, batch_size, parallel, **kwargs)
task_id = (
task_id if task_id is not None else self.default_task_id
) # required for multiprocessing
yield from super().embed(documents, batch_size, parallel, task_id=task_id, **kwargs)

def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
self.current_task_id = self.QUERY_TASK
yield from super().embed(query, **kwargs)
yield from super().embed(query, task_id=self.QUERY_TASK, **kwargs)

def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
self.current_task_id = self.PASSAGE_TASK
yield from super().embed(texts, **kwargs)
yield from super().embed(texts, task_id=self.PASSAGE_TASK, **kwargs)


class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker):
Expand All @@ -90,11 +97,15 @@ def init_embedding(
cache_dir: str,
**kwargs: Any,
) -> JinaEmbeddingV3:
model = JinaEmbeddingV3(
return JinaEmbeddingV3(
model_name=model_name,
cache_dir=cache_dir,
threads=1,
**kwargs,
)
model.current_task_id = kwargs["task_id"]
return model

def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]:
# Runs every (idx, batch) pair from `items` through the model's `onnx_embed`,
# passing the model's `default_task_id` as `task_id`, and yields (idx, onnx_output).
self.model: JinaEmbeddingV3 # annotation only: mypy complains that `self.model` does not have `default_task_id`
for idx, batch in items:
# The task id comes from the model instance rather than from per-call arguments.
onnx_output = self.model.onnx_embed(batch, task_id=self.model.default_task_id)
yield idx, onnx_output
2 changes: 1 addition & 1 deletion fastembed/text/onnx_text_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _embed_documents(
if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()
for batch in iter_batch(documents, batch_size):
yield from self._post_process_onnx_output(self.onnx_embed(batch))
yield from self._post_process_onnx_output(self.onnx_embed(batch, **kwargs))
else:
if parallel == 0:
parallel = os.cpu_count()
Expand Down
43 changes: 19 additions & 24 deletions tests/test_text_multitask_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,25 @@ def test_single_embedding():

canonical_vector = task["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc.model

classification_embeddings = list(model.embed(documents=docs, task_id=Task.CLASSIFICATION))
classification_embeddings = np.stack(classification_embeddings, axis=0)

assert classification_embeddings.shape == (len(docs), dim)

model = TextEmbedding(model_name=model_name, task_id=Task.CLASSIFICATION)
default_embeddings = list(model.embed(documents=docs))
default_embeddings = np.stack(default_embeddings, axis=0)

assert default_embeddings.shape == (len(docs), dim)

assert np.allclose(
classification_embeddings,
default_embeddings,
atol=1e-4,
), model_desc.model
if is_ci:
delete_model_cache(model.model._model_dir)

Expand Down Expand Up @@ -140,7 +156,7 @@ def test_single_embedding_query():

canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc.model

if is_ci:
Expand Down Expand Up @@ -172,7 +188,7 @@ def test_single_embedding_passage():

canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc.model

if is_ci:
Expand Down Expand Up @@ -207,27 +223,6 @@ def test_parallel_processing(dim: int, model_name: str):
delete_model_cache(model.model._model_dir)


def test_task_assignment():
is_ci = os.getenv("CI")
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

if is_ci and not is_manual:
pytest.skip("Skipping in CI non-manual mode")

for model_desc in JinaEmbeddingV3._list_supported_models():
# todo: once we add more models, we should not test models >1GB size locally
model_name = model_desc.model

model = TextEmbedding(model_name=model_name)

for i, task_id in enumerate(Task):
_ = list(model.embed(documents=docs, batch_size=1, task_id=i))
assert model.model.current_task_id == task_id

if is_ci:
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"])
def test_lazy_load(model_name: str):
is_ci = os.getenv("CI")
Expand Down
Loading