9 changes: 8 additions & 1 deletion fastembed/common/model_description.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Any


@@ -28,7 +29,7 @@ class BaseModelDescription:
@dataclass(frozen=True)
class DenseModelDescription(BaseModelDescription):
    dim: Optional[int] = None
-    tasks: Optional[dict[str, Any]] = None
+    tasks: Optional[dict[str, Any]] = field(default_factory=dict)

    def __post_init__(self) -> None:
        assert self.dim is not None, "dim is required for dense model description"
@@ -38,3 +39,9 @@ def __post_init__(self) -> None:
class SparseModelDescription(BaseModelDescription):
    requires_idf: Optional[bool] = None
    vocab_size: Optional[int] = None


class PoolingType(str, Enum):
    CLS = "CLS"
    MEAN = "MEAN"
    DISABLED = "DISABLED"
25 changes: 25 additions & 0 deletions fastembed/common/model_management.py
@@ -33,6 +33,31 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
"""
raise NotImplementedError()

    @classmethod
    def add_custom_model(
        cls,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Add a custom model to the existing embedding classes based on the passed model description.

        The model description dict should contain the same fields as one of the model descriptions
        defined in fastembed.common.model_description.

        E.g. for BaseModelDescription:
            model: str
            sources: ModelSource
            model_file: str
            description: str
            license: str
            size_in_GB: float
            additional_files: list[str]

        Returns:
            None
        """
        raise NotImplementedError()

    @classmethod
    def _list_supported_models(cls) -> list[T]:
        raise NotImplementedError()
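To make the field list in the docstring concrete, a description built from `BaseModelDescription` would carry roughly the following shape. The model name, source, and size below are placeholders, and `ModelSource(hf=...)` is assumed from the existing fastembed API rather than shown in this diff:

```python
from fastembed.common.model_description import BaseModelDescription, ModelSource

# Hypothetical values mirroring the fields listed in the docstring above.
example_description = BaseModelDescription(
    model="my-org/my-model",
    sources=ModelSource(hf="my-org/my-model"),
    model_file="onnx/model.onnx",
    description="Example custom model",
    license="apache-2.0",
    size_in_GB=0.10,
    additional_files=[],
)
```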
10 changes: 10 additions & 0 deletions fastembed/common/utils.py
@@ -8,6 +8,7 @@
from typing import Iterable, Optional, TypeVar

import numpy as np
from numpy.typing import NDArray

from fastembed.common.types import NumpyArray

@@ -22,6 +23,15 @@ def normalize(input_array: NumpyArray, p: int = 2, dim: int = 1, eps: float = 1e
    return normalized_array


def mean_pooling(input_array: NumpyArray, attention_mask: NDArray[np.int64]) -> NumpyArray:
    input_mask_expanded = np.expand_dims(attention_mask, axis=-1).astype(np.int64)
    input_mask_expanded = np.tile(input_mask_expanded, (1, 1, input_array.shape[-1]))
    sum_embeddings = np.sum(input_array * input_mask_expanded, axis=1)
    sum_mask = np.sum(input_mask_expanded, axis=1)
    pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
    return pooled_embeddings


def iter_batch(iterable: Iterable[T], size: int) -> Iterable[list[T]]:
"""
>>> list(iter_batch([1,2,3,4,5], 3))
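As a sanity check on the helper above: mean pooling averages token embeddings over the unmasked positions only, with the `1e-9` floor guarding against division by zero for fully masked rows. A small sketch of the expected behaviour:

```python
import numpy as np
from fastembed.common.utils import mean_pooling

# One sequence of three tokens with 2-dim embeddings; the last token is padding.
token_embeddings = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
attention_mask = np.array([[1, 1, 0]], dtype=np.int64)

pooled = mean_pooling(token_embeddings, attention_mask)
# The padded position is ignored, so the result is the mean of the first two tokens.
assert np.allclose(pooled, [[2.0, 3.0]])
```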
91 changes: 91 additions & 0 deletions fastembed/text/custom_text_embedding.py
@@ -0,0 +1,91 @@
from typing import Optional, Sequence, Any, Iterable

from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray

from fastembed.common import OnnxProvider
from fastembed.common.model_description import (
    PoolingType,
    DenseModelDescription,
)
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.types import NumpyArray
from fastembed.common.utils import normalize, mean_pooling
from fastembed.text.onnx_embedding import OnnxTextEmbedding


@dataclass(frozen=True)
class PostprocessingConfig:
    pooling: PoolingType
    normalization: bool


class CustomTextEmbedding(OnnxTextEmbedding):
    SUPPORTED_MODELS: list[DenseModelDescription] = []
    POSTPROCESSING_MAPPING: dict[str, PostprocessingConfig] = {}

    def __init__(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence[OnnxProvider]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        lazy_load: bool = False,
        device_id: Optional[int] = None,
        specific_model_path: Optional[str] = None,
        **kwargs: Any,
    ):
        super().__init__(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=threads,
            providers=providers,
            cuda=cuda,
            device_ids=device_ids,
            lazy_load=lazy_load,
            device_id=device_id,
            specific_model_path=specific_model_path,
            **kwargs,
        )
        self._pooling = self.POSTPROCESSING_MAPPING[model_name].pooling
        self._normalization = self.POSTPROCESSING_MAPPING[model_name].normalization

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        return cls.SUPPORTED_MODELS

    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
        return self._normalize(self._pool(output.model_output, output.attention_mask))

    def _pool(
        self, embeddings: NumpyArray, attention_mask: Optional[NDArray[np.int64]] = None
    ) -> NumpyArray:
        if self._pooling == PoolingType.CLS:
            return embeddings[:, 0]

        if self._pooling == PoolingType.MEAN:
            if attention_mask is None:
                raise ValueError("attention_mask must be provided for mean pooling")
            return mean_pooling(embeddings, attention_mask)

        if self._pooling == PoolingType.DISABLED:
            return embeddings

    def _normalize(self, embeddings: NumpyArray) -> NumpyArray:
        return normalize(embeddings) if self._normalization else embeddings

    @classmethod
    def add_model(
        cls,
        model_description: DenseModelDescription,
        pooling: PoolingType,
        normalization: bool,
    ) -> None:
        cls.SUPPORTED_MODELS.append(model_description)
        cls.POSTPROCESSING_MAPPING[model_description.model] = PostprocessingConfig(
            pooling=pooling, normalization=normalization
        )
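For orientation (not part of the diff): `add_model` appends the description to the class-level `SUPPORTED_MODELS` registry and keys a `PostprocessingConfig` by model name, which `__init__` later looks up to decide pooling and normalization. A hedged sketch with a made-up model name:

```python
from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType
from fastembed.text.custom_text_embedding import CustomTextEmbedding

# Hypothetical model; the HF repo name is a placeholder.
description = DenseModelDescription(
    model="my-org/my-encoder",
    sources=ModelSource(hf="my-org/my-encoder"),
    dim=384,
    model_file="onnx/model.onnx",
    description="",
    license="",
    size_in_GB=0.1,
    additional_files=[],
)
CustomTextEmbedding.add_model(description, pooling=PoolingType.MEAN, normalization=True)

assert description in CustomTextEmbedding.SUPPORTED_MODELS
assert CustomTextEmbedding.POSTPROCESSING_MAPPING["my-org/my-encoder"].pooling is PoolingType.MEAN
```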
16 changes: 6 additions & 10 deletions fastembed/text/pooled_embedding.py
@@ -1,9 +1,11 @@
from typing import Any, Iterable, Type

import numpy as np
from numpy.typing import NDArray

from fastembed.common.types import NumpyArray
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import mean_pooling
from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
from fastembed.common.model_description import DenseModelDescription, ModelSource

@@ -93,16 +95,10 @@ def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
        return PooledEmbeddingWorker

    @classmethod
-    def mean_pooling(cls, model_output: NumpyArray, attention_mask: NumpyArray) -> NumpyArray:
-        token_embeddings = model_output.astype(np.float32)
-        attention_mask = attention_mask.astype(np.float32)
-        input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
-        input_mask_expanded = np.tile(input_mask_expanded, (1, 1, token_embeddings.shape[-1]))
-        input_mask_expanded = input_mask_expanded.astype(np.float32)
-        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
-        sum_mask = np.sum(input_mask_expanded, axis=1)
-        pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
-        return pooled_embeddings
+    def mean_pooling(
+        cls, model_output: NumpyArray, attention_mask: NDArray[np.int64]
+    ) -> NumpyArray:
+        return mean_pooling(model_output, attention_mask)

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
41 changes: 40 additions & 1 deletion fastembed/text/text_embedding.py
@@ -4,12 +4,13 @@

from fastembed.common.types import NumpyArray, OnnxProvider
from fastembed.text.clip_embedding import CLIPOnnxEmbedding
from fastembed.text.custom_text_embedding import CustomTextEmbedding
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
from fastembed.text.pooled_embedding import PooledEmbedding
from fastembed.text.multitask_embedding import JinaEmbeddingV3
from fastembed.text.onnx_embedding import OnnxTextEmbedding
from fastembed.text.text_embedding_base import TextEmbeddingBase
-from fastembed.common.model_description import DenseModelDescription
+from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType


class TextEmbedding(TextEmbeddingBase):
@@ -19,6 +20,7 @@ class TextEmbedding(TextEmbeddingBase):
        PooledNormalizedEmbedding,
        PooledEmbedding,
        JinaEmbeddingV3,
        CustomTextEmbedding,
    ]

    @classmethod
@@ -37,6 +39,43 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
            result.extend(embedding._list_supported_models())
        return result

    @classmethod
    def add_custom_model(
        cls,
        model: str,
        pooling: PoolingType,
        normalization: bool,
        sources: ModelSource,
        dim: int,
        model_file: str = "onnx/model.onnx",
        description: str = "",
        license: str = "",
        size_in_gb: float = 0.0,
        additional_files: Optional[list[str]] = None,
    ) -> None:
        registered_models = cls._list_supported_models()
        for registered_model in registered_models:
            if model == registered_model.model:
                raise ValueError(
                    f"Model {model} is already registered in TextEmbedding, if you still want to add this model, "
                    f"please use another model name"
                )

        CustomTextEmbedding.add_model(
            DenseModelDescription(
                model=model,
                sources=sources,
                dim=dim,
                model_file=model_file,
                description=description,
                license=license,
                size_in_GB=size_in_gb,
                additional_files=additional_files or [],
            ),
            pooling=pooling,
            normalization=normalization,
        )

    def __init__(
        self,
        model_name: str = "BAAI/bge-small-en-v1.5",
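Putting the pieces together, the intended end-to-end flow is: register the custom model once via `TextEmbedding.add_custom_model`, then instantiate `TextEmbedding` with the new name as usual. A hedged sketch; the model name, HF repo, and dimension are placeholders, and the repo would need to actually host an ONNX export at `onnx/model.onnx`:

```python
from fastembed import TextEmbedding
from fastembed.common.model_description import ModelSource, PoolingType

# Hypothetical custom model registration; must run before instantiation.
TextEmbedding.add_custom_model(
    model="my-org/my-encoder",
    pooling=PoolingType.MEAN,
    normalization=True,
    sources=ModelSource(hf="my-org/my-encoder"),
    dim=384,
)

model = TextEmbedding(model_name="my-org/my-encoder")
embeddings = list(model.embed(["hello world"]))  # one (384,)-shaped numpy array per document
```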