Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,7 @@ jobs:
poetry install --no-interaction --no-ansi --without dev,docs

- name: Run pytest
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
poetry run pytest
poetry run pytest
41 changes: 26 additions & 15 deletions fastembed/text/multitask_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np

from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.types import NumpyArray
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker
Expand Down Expand Up @@ -44,9 +45,11 @@ class JinaEmbeddingV3(PooledNormalizedEmbedding):
PASSAGE_TASK = Task.RETRIEVAL_PASSAGE
QUERY_TASK = Task.RETRIEVAL_QUERY

def __init__(self, *args: Any, **kwargs: Any):
def __init__(self, *args: Any, task_id: Optional[int] = None, **kwargs: Any):
super().__init__(*args, **kwargs)
self.current_task_id: Union[Task, int] = self.PASSAGE_TASK
self.default_task_id: Union[Task, int] = (
task_id if task_id is not None else self.PASSAGE_TASK
)

@classmethod
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
Expand All @@ -57,30 +60,34 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
return supported_multitask_models

def _preprocess_onnx_input(
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
self,
onnx_input: dict[str, NumpyArray],
task_id: Optional[Union[int, Task]] = None,
**kwargs: Any,
) -> dict[str, NumpyArray]:
onnx_input["task_id"] = np.array(self.current_task_id, dtype=np.int64)
if task_id is None:
raise ValueError(f"task_id must be provided for JinaEmbeddingV3, got <{task_id}>")
onnx_input["task_id"] = np.array(task_id, dtype=np.int64)
return onnx_input

def embed(
self,
documents: Union[str, Iterable[str]],
batch_size: int = 256,
parallel: Optional[int] = None,
task_id: int = PASSAGE_TASK,
task_id: Optional[int] = None,
**kwargs: Any,
) -> Iterable[NumpyArray]:
self.current_task_id = task_id
kwargs["task_id"] = task_id
yield from super().embed(documents, batch_size, parallel, **kwargs)
task_id = (
task_id if task_id is not None else self.default_task_id
) # required for multiprocessing
yield from super().embed(documents, batch_size, parallel, task_id=task_id, **kwargs)

def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
self.current_task_id = self.QUERY_TASK
yield from super().embed(query, **kwargs)
yield from super().embed(query, task_id=self.QUERY_TASK, **kwargs)

def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
self.current_task_id = self.PASSAGE_TASK
yield from super().embed(texts, **kwargs)
yield from super().embed(texts, task_id=self.PASSAGE_TASK, **kwargs)


class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker):
Expand All @@ -90,11 +97,15 @@ def init_embedding(
cache_dir: str,
**kwargs: Any,
) -> JinaEmbeddingV3:
model = JinaEmbeddingV3(
return JinaEmbeddingV3(
model_name=model_name,
cache_dir=cache_dir,
threads=1,
**kwargs,
)
model.current_task_id = kwargs["task_id"]
return model

# Worker-side batch loop: for each (idx, batch) pair in `items`, run the model's
# `onnx_embed` on the batch and yield (idx, output) so results stay paired with
# their batch index.
# `task_id` is passed explicitly on every call and is taken from the model's
# `default_task_id`.
# NOTE(review): presumably this overrides the base worker's `process` so that a
# task id always reaches `_preprocess_onnx_input`, which raises ValueError when
# `task_id` is None - confirm against OnnxTextEmbeddingWorker.
def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]:
self.model: JinaEmbeddingV3 # annotation only: mypy otherwise complains that `self.model` has no `default_task_id`
for idx, batch in items:
onnx_output = self.model.onnx_embed(batch, task_id=self.model.default_task_id)
yield idx, onnx_output
2 changes: 1 addition & 1 deletion fastembed/text/onnx_text_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _embed_documents(
if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()
for batch in iter_batch(documents, batch_size):
yield from self._post_process_onnx_output(self.onnx_embed(batch))
yield from self._post_process_onnx_output(self.onnx_embed(batch, **kwargs))
else:
if parallel == 0:
parallel = os.cpu_count()
Expand Down
43 changes: 19 additions & 24 deletions tests/test_text_multitask_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,25 @@ def test_single_embedding():

canonical_vector = task["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc.model

classification_embeddings = list(model.embed(documents=docs, task_id=Task.CLASSIFICATION))
classification_embeddings = np.stack(classification_embeddings, axis=0)

assert classification_embeddings.shape == (len(docs), dim)

model = TextEmbedding(model_name=model_name, task_id=Task.CLASSIFICATION)
default_embeddings = list(model.embed(documents=docs))
default_embeddings = np.stack(default_embeddings, axis=0)

assert default_embeddings.shape == (len(docs), dim)

assert np.allclose(
classification_embeddings,
default_embeddings,
atol=1e-4,
), model_desc.model
if is_ci:
delete_model_cache(model.model._model_dir)

Expand Down Expand Up @@ -140,7 +156,7 @@ def test_single_embedding_query():

canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc.model

if is_ci:
Expand Down Expand Up @@ -172,7 +188,7 @@ def test_single_embedding_passage():

canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc.model

if is_ci:
Expand Down Expand Up @@ -207,27 +223,6 @@ def test_parallel_processing(dim: int, model_name: str):
delete_model_cache(model.model._model_dir)


# Checks that the `task_id` passed to `embed` is recorded on the underlying model
# as `current_task_id`, for every member of `Task` and for every model returned
# by `JinaEmbeddingV3._list_supported_models()`.
# NOTE(review): the enumerate index `i` is what gets passed as `task_id`, while
# the assertion compares against the `Task` member - this assumes each member
# compares equal to its 0-based position; confirm against the Task definition.
def test_task_assignment():
is_ci = os.getenv("CI")
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

# Skipped on CI unless the run was triggered manually (workflow_dispatch).
if is_ci and not is_manual:
pytest.skip("Skipping in CI non-manual mode")

for model_desc in JinaEmbeddingV3._list_supported_models():
# TODO: once we add more models, we should not test models >1GB size locally
model_name = model_desc.model

model = TextEmbedding(model_name=model_name)

for i, task_id in enumerate(Task):
# list() consumes the iterable returned by embed; the embeddings themselves are discarded.
_ = list(model.embed(documents=docs, batch_size=1, task_id=i))
assert model.model.current_task_id == task_id

# On CI, call delete_model_cache on the model's directory.
if is_ci:
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"])
def test_lazy_load(model_name: str):
is_ci = os.getenv("CI")
Expand Down