diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index dcbfcc02..fda03169 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -42,5 +42,7 @@ jobs: poetry install --no-interaction --no-ansi --without dev,docs - name: Run pytest + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | - poetry run pytest \ No newline at end of file + poetry run pytest \ No newline at end of file diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py index a67a337a..d8b28737 100644 --- a/fastembed/text/multitask_embedding.py +++ b/fastembed/text/multitask_embedding.py @@ -3,6 +3,7 @@ import numpy as np +from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.types import NumpyArray from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker @@ -44,9 +45,11 @@ class JinaEmbeddingV3(PooledNormalizedEmbedding): PASSAGE_TASK = Task.RETRIEVAL_PASSAGE QUERY_TASK = Task.RETRIEVAL_QUERY - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, task_id: Optional[int] = None, **kwargs: Any): super().__init__(*args, **kwargs) - self.current_task_id: Union[Task, int] = self.PASSAGE_TASK + self.default_task_id: Union[Task, int] = ( + task_id if task_id is not None else self.PASSAGE_TASK + ) @classmethod def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]: @@ -57,9 +60,14 @@ def _list_supported_models(cls) -> list[DenseModelDescription]: return supported_multitask_models def _preprocess_onnx_input( - self, onnx_input: dict[str, NumpyArray], **kwargs: Any + self, + onnx_input: dict[str, NumpyArray], + task_id: Optional[Union[int, Task]] = None, + **kwargs: Any, ) -> dict[str, NumpyArray]: - onnx_input["task_id"] = np.array(self.current_task_id, dtype=np.int64) + if task_id is None: + raise ValueError(f"task_id must be provided for JinaEmbeddingV3, got <{task_id}>") + onnx_input["task_id"] = np.array(task_id, dtype=np.int64) return onnx_input def embed( @@ -67,20 +75,19 @@ def embed( documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: Optional[int] = None, - task_id: int = PASSAGE_TASK, + task_id: Optional[int] = None, **kwargs: Any, ) -> Iterable[NumpyArray]: - self.current_task_id = task_id - kwargs["task_id"] = task_id - yield from super().embed(documents, batch_size, parallel, **kwargs) + task_id = ( + task_id if task_id is not None else self.default_task_id + ) # required for multiprocessing + yield from super().embed(documents, batch_size, parallel, task_id=task_id, **kwargs) def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]: - self.current_task_id = self.QUERY_TASK - yield from super().embed(query, **kwargs) + yield from super().embed(query, task_id=self.QUERY_TASK, **kwargs) def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]: - self.current_task_id = self.PASSAGE_TASK - yield from super().embed(texts, **kwargs) + yield from super().embed(texts, task_id=self.PASSAGE_TASK, **kwargs) class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker): @@ -90,11 +97,15 @@ def init_embedding( cache_dir: str, **kwargs: Any, ) -> JinaEmbeddingV3: - model = JinaEmbeddingV3( + return JinaEmbeddingV3( model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs, ) - model.current_task_id = kwargs["task_id"] - return model + + def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]: + self.model: JinaEmbeddingV3 # mypy complaints `self.model` does not have `default_task_id` + for idx, batch in items: + onnx_output = self.model.onnx_embed(batch, task_id=self.model.default_task_id) + yield idx, onnx_output diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index 625ab6b3..506bc7b5 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -115,7 +115,7 @@ def _embed_documents( if not hasattr(self, "model") or self.model is None: self.load_onnx_model() for batch in iter_batch(documents, batch_size): - yield from self._post_process_onnx_output(self.onnx_embed(batch)) + yield from self._post_process_onnx_output(self.onnx_embed(batch, **kwargs)) else: if parallel == 0: parallel = os.cpu_count() diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py index 874ffcec..f246b489 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -109,9 +109,25 @@ def test_single_embedding(): canonical_vector = task["vectors"] assert np.allclose( - embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 + embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4 ), model_desc.model + classification_embeddings = list(model.embed(documents=docs, task_id=Task.CLASSIFICATION)) + classification_embeddings = np.stack(classification_embeddings, axis=0) + + assert classification_embeddings.shape == (len(docs), dim) + + model = TextEmbedding(model_name=model_name, task_id=Task.CLASSIFICATION) + default_embeddings = list(model.embed(documents=docs)) + default_embeddings = np.stack(default_embeddings, axis=0) + + assert default_embeddings.shape == (len(docs), dim) + + assert np.allclose( + classification_embeddings, + default_embeddings, + atol=1e-4, + ), model_desc.model if is_ci: delete_model_cache(model.model._model_dir) @@ -140,7 +156,7 @@ def test_single_embedding_query(): canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"] assert np.allclose( - embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 + embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4 ), model_desc.model if is_ci: @@ -172,7 +188,7 @@ def test_single_embedding_passage(): canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"] assert np.allclose( - embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 + embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4 ), model_desc.model if is_ci: @@ -207,27 +223,6 @@ def test_parallel_processing(dim: int, model_name: str): delete_model_cache(model.model._model_dir) -def test_task_assignment(): - is_ci = os.getenv("CI") - is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" - - if is_ci and not is_manual: - pytest.skip("Skipping in CI non-manual mode") - - for model_desc in JinaEmbeddingV3._list_supported_models(): - # todo: once we add more models, we should not test models >1GB size locally - model_name = model_desc.model - - model = TextEmbedding(model_name=model_name) - - for i, task_id in enumerate(Task): - _ = list(model.embed(documents=docs, batch_size=1, task_id=i)) - assert model.model.current_task_id == task_id - - if is_ci: - delete_model_cache(model.model._model_dir) - - @pytest.mark.parametrize("model_name", ["jinaai/jina-embeddings-v3"]) def test_lazy_load(model_name: str): is_ci = os.getenv("CI")