def mean_pooling(input_array: NumpyArray, attention_mask: NDArray[np.int64]) -> NumpyArray:
    """Average token embeddings over the token axis, weighting tokens by the mask.

    Args:
        input_array: Token embeddings, reduced over axis 1 and broadcast against
            the mask as (batch, tokens, dim).
        attention_mask: Per-token mask of shape (batch, tokens); tokens with a
            zero entry do not contribute to the mean.

    Returns:
        Pooled embeddings of shape (batch, dim) in at least float32 precision
        (float32 input stays float32).
    """
    # Compute in a floating dtype of at least single precision. Multiplying
    # float32 embeddings by an int64 mask would silently promote to float64.
    compute_dtype = np.result_type(input_array.dtype, np.float32)
    token_embeddings = input_array.astype(compute_dtype, copy=False)

    # (batch, tokens) -> (batch, tokens, 1); broadcasting over the embedding
    # axis makes an explicit np.tile of the mask unnecessary.
    mask = np.expand_dims(attention_mask, axis=-1).astype(compute_dtype)

    sum_embeddings = np.sum(token_embeddings * mask, axis=1)
    sum_mask = np.sum(mask, axis=1)

    # Clamp the token count so fully masked rows produce zeros instead of NaN.
    return sum_embeddings / np.maximum(sum_mask, 1e-9)
from typing import Optional, Sequence, Any, Iterable

from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray

from fastembed.common import OnnxProvider
from fastembed.common.model_description import (
    PoolingType,
    DenseModelDescription,
)
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.types import NumpyArray
from fastembed.common.utils import normalize, mean_pooling
from fastembed.text.onnx_embedding import OnnxTextEmbedding


@dataclass(frozen=True)
class PostprocessingConfig:
    """Post-processing applied to the raw ONNX output of a custom model."""

    # How token-level output is reduced to one vector per input.
    pooling: PoolingType
    # Whether the pooled vectors are passed through `normalize`.
    normalization: bool


class CustomTextEmbedding(OnnxTextEmbedding):
    """ONNX text embedding for user-registered models.

    Models are registered at runtime through `add_model`, which records both
    the model description and the pooling/normalization to apply to its output.
    """

    # Class-level registries, shared by all instances and filled by `add_model`.
    SUPPORTED_MODELS: list[DenseModelDescription] = []
    POSTPROCESSING_MAPPING: dict[str, PostprocessingConfig] = {}

    def __init__(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence[OnnxProvider]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        lazy_load: bool = False,
        device_id: Optional[int] = None,
        specific_model_path: Optional[str] = None,
        **kwargs: Any,
    ):
        """Initialise the underlying ONNX embedding and look up post-processing.

        All arguments are forwarded unchanged to `OnnxTextEmbedding.__init__`.

        Raises:
            KeyError: If `model_name` has no entry in `POSTPROCESSING_MAPPING`,
                i.e. it was not registered through `add_model` under exactly
                this name.
        """
        super().__init__(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=threads,
            providers=providers,
            cuda=cuda,
            device_ids=device_ids,
            lazy_load=lazy_load,
            device_id=device_id,
            specific_model_path=specific_model_path,
            **kwargs,
        )
        # Single lookup; the mapping is keyed by the name given to `add_model`.
        postprocessing = self.POSTPROCESSING_MAPPING[model_name]
        self._pooling = postprocessing.pooling
        self._normalization = postprocessing.normalization

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Return the descriptions of all models registered via `add_model`."""
        return cls.SUPPORTED_MODELS

    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
        """Pool the raw model output, then normalize it if configured to."""
        return self._normalize(self._pool(output.model_output, output.attention_mask))

    def _pool(
        self, embeddings: NumpyArray, attention_mask: Optional[NDArray[np.int64]] = None
    ) -> NumpyArray:
        """Reduce model output according to the configured pooling type.

        Args:
            embeddings: Raw model output.
            attention_mask: Token mask; required only for mean pooling.

        Returns:
            The first token's embedding for CLS, the masked mean for MEAN, or
            the input unchanged for DISABLED.

        Raises:
            ValueError: If mean pooling is requested without an attention mask,
                or if the configured pooling type is not supported.
        """
        if self._pooling == PoolingType.CLS:
            return embeddings[:, 0]

        if self._pooling == PoolingType.MEAN:
            if attention_mask is None:
                raise ValueError("attention_mask must be provided for mean pooling")
            return mean_pooling(embeddings, attention_mask)

        if self._pooling == PoolingType.DISABLED:
            return embeddings

        # Fail loudly instead of implicitly returning None, which would only
        # surface later as a confusing error in normalization or in the caller.
        raise ValueError(f"Unsupported pooling type: {self._pooling}")

    def _normalize(self, embeddings: NumpyArray) -> NumpyArray:
        """Apply `normalize` when normalization is enabled, else pass through."""
        return normalize(embeddings) if self._normalization else embeddings

    @classmethod
    def add_model(
        cls,
        model_description: DenseModelDescription,
        pooling: PoolingType,
        normalization: bool,
    ) -> None:
        """Register a custom model together with its post-processing settings.

        Args:
            model_description: Description appended to `SUPPORTED_MODELS`.
            pooling: Pooling applied to this model's output.
            normalization: Whether this model's pooled output is normalized.
        """
        cls.SUPPORTED_MODELS.append(model_description)
        cls.POSTPROCESSING_MAPPING[model_description.model] = PostprocessingConfig(
            pooling=pooling, normalization=normalization
        )
diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index cd53dc3b..bed3ca1e 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -195,35 +195,6 @@ class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[NumpyArray]): """Implementation of the Flag Embedding model.""" - CUSTOM_MODELS: list[DenseModelDescription] = [] - - @classmethod - def add_custom_model( - cls, - model: str, - sources: ModelSource, - model_file: str, - dim: int, - description: str, - license: str, - size_in_gb: float, - additional_files: Optional[list[str]] = None, - tasks: Optional[dict[str, Any]] = None, - ) -> None: - cls.CUSTOM_MODELS.append( - DenseModelDescription( - model=model, - sources=sources, - dim=dim, - model_file=model_file, - description=description, - license=license, - size_in_GB=size_in_gb, - additional_files=additional_files if additional_files else [], - tasks=tasks if tasks else {}, - ) - ) - @classmethod def _list_supported_models(cls) -> list[DenseModelDescription]: """ @@ -232,7 +203,7 @@ def _list_supported_models(cls) -> list[DenseModelDescription]: Returns: list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information. 
""" - return supported_onnx_models + cls.CUSTOM_MODELS + return supported_onnx_models def __init__( self, diff --git a/fastembed/text/pooled_embedding.py b/fastembed/text/pooled_embedding.py index f4598ab4..1dc8c9f5 100644 --- a/fastembed/text/pooled_embedding.py +++ b/fastembed/text/pooled_embedding.py @@ -1,9 +1,11 @@ from typing import Any, Iterable, Type import numpy as np +from numpy.typing import NDArray from fastembed.common.types import NumpyArray from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.common.utils import mean_pooling from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker from fastembed.common.model_description import DenseModelDescription, ModelSource @@ -88,23 +90,15 @@ class PooledEmbedding(OnnxTextEmbedding): - CUSTOM_MODELS: list[DenseModelDescription] = [] - @classmethod def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]: return PooledEmbeddingWorker @classmethod - def mean_pooling(cls, model_output: NumpyArray, attention_mask: NumpyArray) -> NumpyArray: - token_embeddings = model_output.astype(np.float32) - attention_mask = attention_mask.astype(np.float32) - input_mask_expanded = np.expand_dims(attention_mask, axis=-1) - input_mask_expanded = np.tile(input_mask_expanded, (1, 1, token_embeddings.shape[-1])) - input_mask_expanded = input_mask_expanded.astype(np.float32) - sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1) - sum_mask = np.sum(input_mask_expanded, axis=1) - pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9) - return pooled_embeddings + def mean_pooling( + cls, model_output: NumpyArray, attention_mask: NDArray[np.int64] + ) -> NumpyArray: + return mean_pooling(model_output, attention_mask) @classmethod def _list_supported_models(cls) -> list[DenseModelDescription]: @@ -113,7 +107,7 @@ def _list_supported_models(cls) -> list[DenseModelDescription]: Returns: list[DenseModelDescription]: A list of DenseModelDescription objects 
containing the model information. """ - return supported_pooled_models + cls.CUSTOM_MODELS + return supported_pooled_models def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]: if output.attention_mask is None: diff --git a/fastembed/text/pooled_normalized_embedding.py b/fastembed/text/pooled_normalized_embedding.py index 0650025f..f0b58b64 100644 --- a/fastembed/text/pooled_normalized_embedding.py +++ b/fastembed/text/pooled_normalized_embedding.py @@ -113,8 +113,6 @@ class PooledNormalizedEmbedding(PooledEmbedding): - CUSTOM_MODELS: list[DenseModelDescription] = [] - @classmethod def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]: return PooledNormalizedEmbeddingWorker @@ -126,7 +124,7 @@ def _list_supported_models(cls) -> list[DenseModelDescription]: Returns: list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information. """ - return supported_pooled_normalized_models + cls.CUSTOM_MODELS + return supported_pooled_normalized_models def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]: if output.attention_mask is None: diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index fe744ad8..c94a5883 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -4,6 +4,7 @@ from fastembed.common.types import NumpyArray, OnnxProvider from fastembed.text.clip_embedding import CLIPOnnxEmbedding +from fastembed.text.custom_text_embedding import CustomTextEmbedding from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding from fastembed.text.pooled_embedding import PooledEmbedding from fastembed.text.multitask_embedding import JinaEmbeddingV3 @@ -19,6 +20,7 @@ class TextEmbedding(TextEmbeddingBase): PooledNormalizedEmbedding, PooledEmbedding, JinaEmbeddingV3, + CustomTextEmbedding, ] @classmethod @@ -50,7 +52,6 @@ def add_custom_model( license: str = "", size_in_gb: float = 0.0, 
additional_files: Optional[list[str]] = None, - tasks: Optional[dict[str, Any]] = None, ) -> None: registered_models = cls._list_supported_models() for registered_model in registered_models: @@ -60,56 +61,20 @@ def add_custom_model( f"please use another model name" ) - if tasks: - if pooling == PoolingType.MEAN and normalization: - JinaEmbeddingV3.add_custom_model( - model=model, - sources=sources, - dim=dim, - model_file=model_file, - description=description, - license=license, - size_in_gb=size_in_gb, - additional_files=additional_files, - tasks=tasks, - ) - return None - else: - raise ValueError( - "Multitask models supported only with pooling=Pooling.MEAN and normalization=True, current values:" - f"pooling={pooling}, normalization={normalization}, tasks: {tasks}" - ) - - embedding_cls: Type[OnnxTextEmbedding] - if pooling == PoolingType.MEAN and normalization: - embedding_cls = PooledNormalizedEmbedding - elif pooling == PoolingType.MEAN and not normalization: - embedding_cls = PooledEmbedding - elif (pooling == PoolingType.CLS or PoolingType.DISABLED) and normalization: - embedding_cls = OnnxTextEmbedding - elif pooling == PoolingType.DISABLED and not normalization: - embedding_cls = CLIPOnnxEmbedding - else: - raise ValueError( - "Only the following combinations of pooling and normalization are currently supported:" - "pooling=Pooling.MEAN + normalization=True;\n" - "pooling=Pooling.MEAN + normalization=False;\n" - "pooling=Pooling.CLS + normalization=True;\n" - "pooling=Pooling.DISABLED + normalization=False;\n" - ) - - embedding_cls.add_custom_model( - model=model, - sources=sources, - dim=dim, - model_file=model_file, - description=description, - license=license, - size_in_gb=size_in_gb, - additional_files=additional_files, - tasks=tasks, + CustomTextEmbedding.add_model( + DenseModelDescription( + model=model, + sources=sources, + dim=dim, + model_file=model_file, + description=description, + license=license, + size_in_GB=size_in_gb, + 
@pytest.fixture(autouse=True)
def restore_custom_models_fixture():
    """Give every test an empty custom-model registry and clear it afterwards.

    Both class-level registries are reset. Clearing only the model list would
    leave post-processing configs from an earlier test behind.
    """
    CustomTextEmbedding.SUPPORTED_MODELS = []
    CustomTextEmbedding.POSTPROCESSING_MAPPING = {}
    yield
    CustomTextEmbedding.SUPPORTED_MODELS = []
    CustomTextEmbedding.POSTPROCESSING_MAPPING = {}
def test_mock_add_custom_models():
    """Check custom-model post-processing for every pooling/normalization pair.

    Each combination is registered through `TextEmbedding.add_custom_model` and
    then run through `CustomTextEmbedding._post_process_onnx_output` on
    synthetic model output, so no model is downloaded or loaded.
    """
    dim = 5
    size_in_gb = 0.1
    source = ModelSource(hf="artificial")

    num_tokens = 10
    # Seeded generator so that a failure can be reproduced exactly.
    rng = np.random.default_rng(42)
    dummy_pooled_embedding = rng.random((1, dim)).astype(np.float32)
    dummy_token_embedding = rng.random((1, num_tokens, dim)).astype(np.float32)
    dummy_attention_mask = np.ones((1, num_tokens), dtype=np.int64)

    dummy_token_output = OnnxOutputContext(
        model_output=dummy_token_embedding, attention_mask=dummy_attention_mask
    )
    dummy_pooled_output = OnnxOutputContext(model_output=dummy_pooled_embedding)

    # pooling type -> (model output fed to post-processing, expected pooled result)
    pooling_cases = {
        PoolingType.MEAN: (
            dummy_token_output,
            mean_pooling(dummy_token_embedding, dummy_attention_mask),
        ),
        PoolingType.CLS: (dummy_token_output, dummy_token_embedding[:, 0]),
        PoolingType.DISABLED: (dummy_pooled_output, dummy_pooled_embedding),
    }

    for pooling, normalization in itertools.product(pooling_cases, (True, False)):
        onnx_output, expected_pooled = pooling_cases[pooling]
        expected = normalize(expected_pooled) if normalization else expected_pooled

        model_name = f"{pooling.name.lower()}{'-normalized' if normalization else ''}"
        TextEmbedding.add_custom_model(
            model_name,
            pooling=pooling,
            normalization=normalization,
            sources=source,
            dim=dim,
            size_in_gb=size_in_gb,
        )

        custom_text_embedding = CustomTextEmbedding(
            model_name,
            lazy_load=True,
            specific_model_path="./",  # disable model downloading and loading
        )

        post_processed_output = next(
            iter(custom_text_embedding._post_process_onnx_output(onnx_output))
        )
        assert np.allclose(post_processed_output, expected[0], atol=1e-3)