diff --git a/fastembed/common/model_description.py b/fastembed/common/model_description.py index 43e42f5f..432f8186 100644 --- a/fastembed/common/model_description.py +++ b/fastembed/common/model_description.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from enum import Enum from typing import Optional, Any @@ -28,7 +29,7 @@ class BaseModelDescription: @dataclass(frozen=True) class DenseModelDescription(BaseModelDescription): dim: Optional[int] = None - tasks: Optional[dict[str, Any]] = None + tasks: Optional[dict[str, Any]] = field(default_factory=dict) def __post_init__(self) -> None: assert self.dim is not None, "dim is required for dense model description" @@ -38,3 +39,9 @@ def __post_init__(self) -> None: class SparseModelDescription(BaseModelDescription): requires_idf: Optional[bool] = None vocab_size: Optional[int] = None + + +class PoolingType(str, Enum): + CLS = "CLS" + MEAN = "MEAN" + DISABLED = "DISABLED" diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py index c16c4670..33e0994f 100644 --- a/fastembed/common/model_management.py +++ b/fastembed/common/model_management.py @@ -33,6 +33,31 @@ def list_supported_models(cls) -> list[dict[str, Any]]: """ raise NotImplementedError() + @classmethod + def add_custom_model( + cls, + *args: Any, + **kwargs: Any, + ) -> None: + """Add a custom model to the existing embedding classes based on the passed model descriptions + + Model description dict should contain the fields same as in one of the model descriptions presented + in fastembed.common.model_description + + E.g. for BaseModelDescription: + model: str + sources: ModelSource + model_file: str + description: str + license: str + size_in_GB: float + additional_files: list[str] + + Returns: + None + """ + raise NotImplementedError() + @classmethod def _list_supported_models(cls) -> list[T]: raise NotImplementedError() diff --git a/fastembed/common/utils.py b/fastembed/common/utils.py index 19c0efef..02ff615b 100644 --- a/fastembed/common/utils.py +++ b/fastembed/common/utils.py @@ -8,6 +8,7 @@ from typing import Iterable, Optional, TypeVar import numpy as np +from numpy.typing import NDArray from fastembed.common.types import NumpyArray @@ -22,6 +23,15 @@ def normalize(input_array: NumpyArray, p: int = 2, dim: int = 1, eps: float = 1e return normalized_array +def mean_pooling(input_array: NumpyArray, attention_mask: NDArray[np.int64]) -> NumpyArray: + input_mask_expanded = np.expand_dims(attention_mask, axis=-1).astype(np.int64) + input_mask_expanded = np.tile(input_mask_expanded, (1, 1, input_array.shape[-1])) + sum_embeddings = np.sum(input_array * input_mask_expanded, axis=1) + sum_mask = np.sum(input_mask_expanded, axis=1) + pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9) + return pooled_embeddings + + def iter_batch(iterable: Iterable[T], size: int) -> Iterable[list[T]]: """ >>> list(iter_batch([1,2,3,4,5], 3)) diff --git a/fastembed/text/custom_text_embedding.py b/fastembed/text/custom_text_embedding.py new file mode 100644 index 00000000..13ec9ed1 --- /dev/null +++ b/fastembed/text/custom_text_embedding.py @@ -0,0 +1,91 @@ +from typing import Optional, Sequence, Any, Iterable + +from dataclasses import dataclass + +import numpy as np +from numpy.typing import NDArray + +from fastembed.common import OnnxProvider +from fastembed.common.model_description import ( + PoolingType, + DenseModelDescription, +) +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.common.types import NumpyArray +from fastembed.common.utils import normalize, mean_pooling +from fastembed.text.onnx_embedding import OnnxTextEmbedding + + +@dataclass(frozen=True) +class PostprocessingConfig: + pooling: PoolingType + normalization: bool + + +class CustomTextEmbedding(OnnxTextEmbedding): + SUPPORTED_MODELS: list[DenseModelDescription] = [] + POSTPROCESSING_MAPPING: dict[str, PostprocessingConfig] = {} + + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + lazy_load: bool = False, + device_id: Optional[int] = None, + specific_model_path: Optional[str] = None, + **kwargs: Any, + ): + super().__init__( + model_name=model_name, + cache_dir=cache_dir, + threads=threads, + providers=providers, + cuda=cuda, + device_ids=device_ids, + lazy_load=lazy_load, + device_id=device_id, + specific_model_path=specific_model_path, + **kwargs, + ) + self._pooling = self.POSTPROCESSING_MAPPING[model_name].pooling + self._normalization = self.POSTPROCESSING_MAPPING[model_name].normalization + + @classmethod + def _list_supported_models(cls) -> list[DenseModelDescription]: + return cls.SUPPORTED_MODELS + + def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]: + return self._normalize(self._pool(output.model_output, output.attention_mask)) + + def _pool( + self, embeddings: NumpyArray, attention_mask: Optional[NDArray[np.int64]] = None + ) -> NumpyArray: + if self._pooling == PoolingType.CLS: + return embeddings[:, 0] + + if self._pooling == PoolingType.MEAN: + if attention_mask is None: + raise ValueError("attention_mask must be provided for mean pooling") + return mean_pooling(embeddings, attention_mask) + + if self._pooling == PoolingType.DISABLED: + return embeddings + + def _normalize(self, embeddings: NumpyArray) -> NumpyArray: + return normalize(embeddings) if self._normalization else embeddings + + @classmethod + def add_model( + cls, + model_description: DenseModelDescription, + pooling: PoolingType, + normalization: bool, + ) -> None: + cls.SUPPORTED_MODELS.append(model_description) + cls.POSTPROCESSING_MAPPING[model_description.model] = PostprocessingConfig( + pooling=pooling, normalization=normalization + ) diff --git a/fastembed/text/pooled_embedding.py b/fastembed/text/pooled_embedding.py index 33236635..1dc8c9f5 100644 --- a/fastembed/text/pooled_embedding.py +++ b/fastembed/text/pooled_embedding.py @@ -1,9 +1,11 @@ from typing import Any, Iterable, Type import numpy as np +from numpy.typing import NDArray from fastembed.common.types import NumpyArray from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.common.utils import mean_pooling from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker from fastembed.common.model_description import DenseModelDescription, ModelSource @@ -93,16 +95,10 @@ def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]: return PooledEmbeddingWorker @classmethod - def mean_pooling(cls, model_output: NumpyArray, attention_mask: NumpyArray) -> NumpyArray: - token_embeddings = model_output.astype(np.float32) - attention_mask = attention_mask.astype(np.float32) - input_mask_expanded = np.expand_dims(attention_mask, axis=-1) - input_mask_expanded = np.tile(input_mask_expanded, (1, 1, token_embeddings.shape[-1])) - input_mask_expanded = input_mask_expanded.astype(np.float32) - sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1) - sum_mask = np.sum(input_mask_expanded, axis=1) - pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9) - return pooled_embeddings + def mean_pooling( + cls, model_output: NumpyArray, attention_mask: NDArray[np.int64] + ) -> NumpyArray: + return mean_pooling(model_output, attention_mask) @classmethod def _list_supported_models(cls) -> list[DenseModelDescription]: diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index 430d7d3a..c94a5883 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -4,12 +4,13 @@ from fastembed.common.types import NumpyArray, OnnxProvider from fastembed.text.clip_embedding import CLIPOnnxEmbedding +from fastembed.text.custom_text_embedding import CustomTextEmbedding from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding from fastembed.text.pooled_embedding import PooledEmbedding from fastembed.text.multitask_embedding import JinaEmbeddingV3 from fastembed.text.onnx_embedding import OnnxTextEmbedding from fastembed.text.text_embedding_base import TextEmbeddingBase -from fastembed.common.model_description import DenseModelDescription +from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType class TextEmbedding(TextEmbeddingBase): @@ -19,6 +20,7 @@ class TextEmbedding(TextEmbeddingBase): PooledNormalizedEmbedding, PooledEmbedding, JinaEmbeddingV3, + CustomTextEmbedding, ] @classmethod @@ -37,6 +39,43 @@ def _list_supported_models(cls) -> list[DenseModelDescription]: result.extend(embedding._list_supported_models()) return result + @classmethod + def add_custom_model( + cls, + model: str, + pooling: PoolingType, + normalization: bool, + sources: ModelSource, + dim: int, + model_file: str = "onnx/model.onnx", + description: str = "", + license: str = "", + size_in_gb: float = 0.0, + additional_files: Optional[list[str]] = None, + ) -> None: + registered_models = cls._list_supported_models() + for registered_model in registered_models: + if model == registered_model.model: + raise ValueError( + f"Model {model} is already registered in TextEmbedding, if you still want to add this model, " + f"please use another model name" + ) + + CustomTextEmbedding.add_model( + DenseModelDescription( + model=model, + sources=sources, + dim=dim, + model_file=model_file, + description=description, + license=license, + size_in_GB=size_in_gb, + additional_files=additional_files or [], + ), + pooling=pooling, + normalization=normalization, + ) + def __init__( self, model_name: str = "BAAI/bge-small-en-v1.5", diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py new file mode 100644 index 00000000..f893c57d --- /dev/null +++ b/tests/test_custom_models.py @@ -0,0 +1,162 @@ +import itertools +import os +import numpy as np +import pytest + +from fastembed.common.model_description import PoolingType, ModelSource, DenseModelDescription +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.common.utils import normalize, mean_pooling +from fastembed.text.custom_text_embedding import CustomTextEmbedding, PostprocessingConfig +from fastembed.text.text_embedding import TextEmbedding +from tests.utils import delete_model_cache + + +@pytest.fixture(autouse=True) +def restore_custom_models_fixture(): + CustomTextEmbedding.SUPPORTED_MODELS = [] + yield + CustomTextEmbedding.SUPPORTED_MODELS = [] + + +def test_text_custom_model(): + is_ci = os.getenv("CI") + custom_model_name = "intfloat/multilingual-e5-small" + canonical_vector = np.array( + [3.1317e-02, 3.0939e-02, -3.5117e-02, -6.7274e-02, 8.5084e-02], dtype=np.float32 + ) + pooling = PoolingType.MEAN + normalization = True + dim = 384 + size_in_gb = 0.47 + source = ModelSource(hf=custom_model_name) + + TextEmbedding.add_custom_model( + custom_model_name, + pooling=pooling, + normalization=normalization, + sources=source, + dim=dim, + size_in_gb=size_in_gb, + ) + + assert CustomTextEmbedding.SUPPORTED_MODELS[0] == DenseModelDescription( + model=custom_model_name, + sources=source, + model_file="onnx/model.onnx", + description="", + license="", + size_in_GB=size_in_gb, + additional_files=[], + dim=dim, + tasks={}, + ) + assert CustomTextEmbedding.POSTPROCESSING_MAPPING[custom_model_name] == PostprocessingConfig( + pooling=pooling, normalization=normalization + ) + + model = TextEmbedding(custom_model_name) + docs = ["hello world", "flag embedding"] + embeddings = list(model.embed(docs)) + embeddings = np.stack(embeddings, axis=0) + assert embeddings.shape == (2, dim) + + assert np.allclose(embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3) + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_mock_add_custom_models(): + dim = 5 + size_in_gb = 0.1 + source = ModelSource(hf="artificial") + + num_tokens = 10 + dummy_pooled_embedding = np.random.random((1, dim)).astype(np.float32) + dummy_token_embedding = np.random.random((1, num_tokens, dim)).astype(np.float32) + dummy_attention_mask = np.ones((1, num_tokens)).astype(np.int64) + + dummy_token_output = OnnxOutputContext( + model_output=dummy_token_embedding, attention_mask=dummy_attention_mask + ) + dummy_pooled_output = OnnxOutputContext(model_output=dummy_pooled_embedding) + input_data = { + f"{PoolingType.MEAN.lower()}-normalized": dummy_token_output, + f"{PoolingType.MEAN.lower()}": dummy_token_output, + f"{PoolingType.CLS.lower()}-normalized": dummy_token_output, + f"{PoolingType.CLS.lower()}": dummy_token_output, + f"{PoolingType.DISABLED.lower()}-normalized": dummy_pooled_output, + f"{PoolingType.DISABLED.lower()}": dummy_pooled_output, + } + + expected_output = { + f"{PoolingType.MEAN.lower()}-normalized": normalize( + mean_pooling(dummy_token_embedding, dummy_attention_mask) + ).astype(np.float32), + f"{PoolingType.MEAN.lower()}": mean_pooling(dummy_token_embedding, dummy_attention_mask), + f"{PoolingType.CLS.lower()}-normalized": normalize(dummy_token_embedding[:, 0]).astype( + np.float32 + ), + f"{PoolingType.CLS.lower()}": dummy_token_embedding[:, 0], + f"{PoolingType.DISABLED.lower()}-normalized": normalize(dummy_pooled_embedding).astype( + np.float32 + ), + f"{PoolingType.DISABLED.lower()}": dummy_pooled_embedding, + } + + for pooling, normalization in itertools.product( + (PoolingType.MEAN, PoolingType.CLS, PoolingType.DISABLED), (True, False) + ): + model_name = f"{pooling.name.lower()}{'-normalized' if normalization else ''}" + TextEmbedding.add_custom_model( + model_name, + pooling=pooling, + normalization=normalization, + sources=source, + dim=dim, + size_in_gb=size_in_gb, + ) + + custom_text_embedding = CustomTextEmbedding( + model_name, + lazy_load=True, + specific_model_path="./", # disable model downloading and loading + ) + + post_processed_output = next( + iter(custom_text_embedding._post_process_onnx_output(input_data[model_name])) + ) + assert np.allclose(post_processed_output, expected_output[model_name], atol=1e-3) + + +def test_do_not_add_existing_model(): + existing_base_model = "sentence-transformers/all-MiniLM-L6-v2" + custom_model_name = "intfloat/multilingual-e5-small" + + with pytest.raises(ValueError, match=f"Model {existing_base_model} is already registered"): + TextEmbedding.add_custom_model( + existing_base_model, + pooling=PoolingType.MEAN, + normalization=True, + sources=ModelSource(hf=existing_base_model), + dim=384, + size_in_gb=0.47, + ) + + TextEmbedding.add_custom_model( + custom_model_name, + pooling=PoolingType.MEAN, + normalization=False, + sources=ModelSource(hf=existing_base_model), + dim=384, + size_in_gb=0.47, + ) + + with pytest.raises(ValueError, match=f"Model {custom_model_name} is already registered"): + TextEmbedding.add_custom_model( + custom_model_name, + pooling=PoolingType.MEAN, + normalization=True, + sources=ModelSource(hf=custom_model_name), + dim=384, + size_in_gb=0.47, + )