Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions fastembed/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Iterable, Optional, TypeVar

import numpy as np
from numpy.typing import NDArray

from fastembed.common.types import NumpyArray

Expand All @@ -22,6 +23,15 @@ def normalize(input_array: NumpyArray, p: int = 2, dim: int = 1, eps: float = 1e
return normalized_array


def mean_pooling(input_array: NumpyArray, attention_mask: NDArray[np.int64]) -> NumpyArray:
    """Average token embeddings over the sequence axis, weighted by the attention mask.

    Args:
        input_array: Token embeddings, expected as (batch, sequence, hidden): the
            sum runs over axis 1 and the mask is broadcast over the last axis.
        attention_mask: Per-token mask, expected as (batch, sequence). Each entry
            is multiplied into the matching token embedding, so 1 keeps a token
            and 0 drops it.

    Returns:
        Pooled embeddings of shape (batch, hidden).
    """
    # A trailing singleton axis lets NumPy broadcast the mask across the hidden
    # dimension, so no tiled (batch, sequence, hidden) copy has to be allocated.
    mask = np.expand_dims(attention_mask, axis=-1).astype(np.int64)
    sum_embeddings = np.sum(input_array * mask, axis=1)
    sum_mask = np.sum(mask, axis=1)
    # Clamp the denominator so a row whose mask is all zeros yields zeros
    # instead of dividing by zero.
    return sum_embeddings / np.maximum(sum_mask, 1e-9)


def iter_batch(iterable: Iterable[T], size: int) -> Iterable[list[T]]:
"""
>>> list(iter_batch([1,2,3,4,5], 3))
Expand Down
4 changes: 1 addition & 3 deletions fastembed/text/clip_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@


class CLIPOnnxEmbedding(OnnxTextEmbedding):
CUSTOM_MODELS: list[DenseModelDescription] = []

@classmethod
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
return CLIPEmbeddingWorker
Expand All @@ -35,7 +33,7 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
Returns:
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
"""
return supported_clip_models + cls.CUSTOM_MODELS
return supported_clip_models

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
return output.model_output
Expand Down
91 changes: 91 additions & 0 deletions fastembed/text/custom_text_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Optional, Sequence, Any, Iterable

from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray

from fastembed.common import OnnxProvider
from fastembed.common.model_description import (
PoolingType,
DenseModelDescription,
)
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.types import NumpyArray
from fastembed.common.utils import normalize, mean_pooling
from fastembed.text.onnx_embedding import OnnxTextEmbedding


@dataclass(frozen=True)
class PostprocessingConfig:
    """Immutable pair of post-processing settings kept for one custom model."""

    # Pooling strategy to apply to the raw model output.
    pooling: PoolingType
    # Whether the pooled embeddings should be normalized afterwards.
    normalization: bool


class CustomTextEmbedding(OnnxTextEmbedding):
    """ONNX text embedding for user-registered models with configurable post-processing.

    Models are registered class-wide through `add_model`, which stores the model
    description together with the pooling / normalization settings that are
    applied to the raw ONNX output of that model.
    """

    # Descriptions of every model registered through `add_model` (shared by the class).
    SUPPORTED_MODELS: list[DenseModelDescription] = []
    # Model name -> post-processing settings recorded when the model was registered.
    POSTPROCESSING_MAPPING: dict[str, PostprocessingConfig] = {}

    def __init__(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence[OnnxProvider]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        lazy_load: bool = False,
        device_id: Optional[int] = None,
        specific_model_path: Optional[str] = None,
        **kwargs: Any,
    ):
        """Initialise the parent embedding, then load this model's post-processing settings.

        All arguments are forwarded unchanged to `OnnxTextEmbedding.__init__`.

        Raises:
            KeyError: If `model_name` has no entry in `POSTPROCESSING_MAPPING`,
                i.e. it was not registered through `add_model` under that exact name.
        """
        super().__init__(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=threads,
            providers=providers,
            cuda=cuda,
            device_ids=device_ids,
            lazy_load=lazy_load,
            device_id=device_id,
            specific_model_path=specific_model_path,
            **kwargs,
        )
        # Single lookup; both settings come from the same registration record.
        postprocessing = self.POSTPROCESSING_MAPPING[model_name]
        self._pooling = postprocessing.pooling
        self._normalization = postprocessing.normalization

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Return the descriptions of all models registered through `add_model`."""
        return cls.SUPPORTED_MODELS

    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
        """Pool the raw model output, then normalize it if configured to do so."""
        return self._normalize(self._pool(output.model_output, output.attention_mask))

    def _pool(
        self, embeddings: NumpyArray, attention_mask: Optional[NDArray[np.int64]] = None
    ) -> NumpyArray:
        """Apply the configured pooling strategy to `embeddings`.

        Args:
            embeddings: Raw model output.
            attention_mask: Token mask; only required for mean pooling.

        Returns:
            `embeddings[:, 0]` for CLS pooling, the mask-weighted mean for MEAN
            pooling, or `embeddings` unchanged when pooling is DISABLED.

        Raises:
            ValueError: If mean pooling is configured but no attention mask is
                given, or if the configured pooling type is not handled here.
        """
        if self._pooling == PoolingType.CLS:
            return embeddings[:, 0]

        if self._pooling == PoolingType.MEAN:
            if attention_mask is None:
                raise ValueError("attention_mask must be provided for mean pooling")
            return mean_pooling(embeddings, attention_mask)

        if self._pooling == PoolingType.DISABLED:
            return embeddings

        # Fail loudly instead of implicitly returning None, which would only
        # surface later as an obscure error inside normalization or the caller.
        raise ValueError(
            f"Unsupported pooling type {self._pooling}. "
            f"Supported types are: {PoolingType.CLS}, {PoolingType.MEAN}, {PoolingType.DISABLED}."
        )

    def _normalize(self, embeddings: NumpyArray) -> NumpyArray:
        """Return `normalize(embeddings)` when normalization is enabled, else the input."""
        return normalize(embeddings) if self._normalization else embeddings

    @classmethod
    def add_model(
        cls,
        model_description: DenseModelDescription,
        pooling: PoolingType,
        normalization: bool,
    ) -> None:
        """Register a custom model together with its post-processing settings.

        Args:
            model_description: Description appended to `SUPPORTED_MODELS`.
            pooling: Pooling strategy to apply to this model's output.
            normalization: Whether this model's pooled output is normalized.
        """
        cls.SUPPORTED_MODELS.append(model_description)
        # Keyed by the model name so `__init__` can find the settings again.
        cls.POSTPROCESSING_MAPPING[model_description.model] = PostprocessingConfig(
            pooling=pooling, normalization=normalization
        )
3 changes: 1 addition & 2 deletions fastembed/text/multitask_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class Task(int, Enum):
class JinaEmbeddingV3(PooledNormalizedEmbedding):
PASSAGE_TASK = Task.RETRIEVAL_PASSAGE
QUERY_TASK = Task.RETRIEVAL_QUERY
CUSTOM_MODELS: list[DenseModelDescription] = []

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
Expand All @@ -55,7 +54,7 @@ def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:

@classmethod
def _list_supported_models(cls) -> list[DenseModelDescription]:
return supported_multitask_models + cls.CUSTOM_MODELS
return supported_multitask_models

def _preprocess_onnx_input(
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
Expand Down
31 changes: 1 addition & 30 deletions fastembed/text/onnx_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,35 +195,6 @@
class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[NumpyArray]):
"""Implementation of the Flag Embedding model."""

CUSTOM_MODELS: list[DenseModelDescription] = []

@classmethod
def add_custom_model(
    cls,
    model: str,
    sources: ModelSource,
    model_file: str,
    dim: int,
    description: str,
    license: str,
    size_in_gb: float,
    additional_files: Optional[list[str]] = None,
    tasks: Optional[dict[str, Any]] = None,
) -> None:
    """Build a `DenseModelDescription` from the arguments and append it to `cls.CUSTOM_MODELS`.

    `additional_files` and `tasks` fall back to a fresh empty list / dict when
    they are omitted or empty; every other argument is passed through as given
    (`size_in_gb` is stored under the description's `size_in_GB` field).
    """
    custom_description = DenseModelDescription(
        model=model,
        sources=sources,
        dim=dim,
        model_file=model_file,
        description=description,
        license=license,
        size_in_GB=size_in_gb,
        # Falsy inputs (None or empty) are replaced by new empty containers.
        additional_files=additional_files or [],
        tasks=tasks or {},
    )
    cls.CUSTOM_MODELS.append(custom_description)

@classmethod
def _list_supported_models(cls) -> list[DenseModelDescription]:
"""
Expand All @@ -232,7 +203,7 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
Returns:
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
"""
return supported_onnx_models + cls.CUSTOM_MODELS
return supported_onnx_models

def __init__(
self,
Expand Down
20 changes: 7 additions & 13 deletions fastembed/text/pooled_embedding.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any, Iterable, Type

import numpy as np
from numpy.typing import NDArray

from fastembed.common.types import NumpyArray
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import mean_pooling
from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
from fastembed.common.model_description import DenseModelDescription, ModelSource

Expand Down Expand Up @@ -88,23 +90,15 @@


class PooledEmbedding(OnnxTextEmbedding):
CUSTOM_MODELS: list[DenseModelDescription] = []

@classmethod
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
return PooledEmbeddingWorker

@classmethod
def mean_pooling(cls, model_output: NumpyArray, attention_mask: NumpyArray) -> NumpyArray:
token_embeddings = model_output.astype(np.float32)
attention_mask = attention_mask.astype(np.float32)
input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
input_mask_expanded = np.tile(input_mask_expanded, (1, 1, token_embeddings.shape[-1]))
input_mask_expanded = input_mask_expanded.astype(np.float32)
sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
sum_mask = np.sum(input_mask_expanded, axis=1)
pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
return pooled_embeddings
def mean_pooling(
cls, model_output: NumpyArray, attention_mask: NDArray[np.int64]
) -> NumpyArray:
return mean_pooling(model_output, attention_mask)

@classmethod
def _list_supported_models(cls) -> list[DenseModelDescription]:
Expand All @@ -113,7 +107,7 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
Returns:
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
"""
return supported_pooled_models + cls.CUSTOM_MODELS
return supported_pooled_models

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
if output.attention_mask is None:
Expand Down
4 changes: 1 addition & 3 deletions fastembed/text/pooled_normalized_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@


class PooledNormalizedEmbedding(PooledEmbedding):
CUSTOM_MODELS: list[DenseModelDescription] = []

@classmethod
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
return PooledNormalizedEmbeddingWorker
Expand All @@ -126,7 +124,7 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
Returns:
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
"""
return supported_pooled_normalized_models + cls.CUSTOM_MODELS
return supported_pooled_normalized_models

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
if output.attention_mask is None:
Expand Down
65 changes: 15 additions & 50 deletions fastembed/text/text_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from fastembed.common.types import NumpyArray, OnnxProvider
from fastembed.text.clip_embedding import CLIPOnnxEmbedding
from fastembed.text.custom_text_embedding import CustomTextEmbedding
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
from fastembed.text.pooled_embedding import PooledEmbedding
from fastembed.text.multitask_embedding import JinaEmbeddingV3
Expand All @@ -19,6 +20,7 @@ class TextEmbedding(TextEmbeddingBase):
PooledNormalizedEmbedding,
PooledEmbedding,
JinaEmbeddingV3,
CustomTextEmbedding,
]

@classmethod
Expand Down Expand Up @@ -50,7 +52,6 @@ def add_custom_model(
license: str = "",
size_in_gb: float = 0.0,
additional_files: Optional[list[str]] = None,
tasks: Optional[dict[str, Any]] = None,
) -> None:
registered_models = cls._list_supported_models()
for registered_model in registered_models:
Expand All @@ -60,56 +61,20 @@ def add_custom_model(
f"please use another model name"
)

if tasks:
if pooling == PoolingType.MEAN and normalization:
JinaEmbeddingV3.add_custom_model(
model=model,
sources=sources,
dim=dim,
model_file=model_file,
description=description,
license=license,
size_in_gb=size_in_gb,
additional_files=additional_files,
tasks=tasks,
)
return None
else:
raise ValueError(
"Multitask models supported only with pooling=Pooling.MEAN and normalization=True, current values:"
f"pooling={pooling}, normalization={normalization}, tasks: {tasks}"
)

embedding_cls: Type[OnnxTextEmbedding]
if pooling == PoolingType.MEAN and normalization:
embedding_cls = PooledNormalizedEmbedding
elif pooling == PoolingType.MEAN and not normalization:
embedding_cls = PooledEmbedding
elif (pooling == PoolingType.CLS or PoolingType.DISABLED) and normalization:
embedding_cls = OnnxTextEmbedding
elif pooling == PoolingType.DISABLED and not normalization:
embedding_cls = CLIPOnnxEmbedding
else:
raise ValueError(
"Only the following combinations of pooling and normalization are currently supported:"
"pooling=Pooling.MEAN + normalization=True;\n"
"pooling=Pooling.MEAN + normalization=False;\n"
"pooling=Pooling.CLS + normalization=True;\n"
"pooling=Pooling.DISABLED + normalization=False;\n"
)

embedding_cls.add_custom_model(
model=model,
sources=sources,
dim=dim,
model_file=model_file,
description=description,
license=license,
size_in_gb=size_in_gb,
additional_files=additional_files,
tasks=tasks,
CustomTextEmbedding.add_model(
DenseModelDescription(
model=model,
sources=sources,
dim=dim,
model_file=model_file,
description=description,
license=license,
size_in_GB=size_in_gb,
additional_files=additional_files or [],
),
pooling=pooling,
normalization=normalization,
)
return None

def __init__(
self,
Expand Down
Loading