Skip to content

Commit 2d254c0

Browse files
committed
refactor: refactor custom models
1 parent 1d3510d commit 2d254c0

File tree

9 files changed

+185
-236
lines changed

9 files changed

+185
-236
lines changed

fastembed/common/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ def normalize(input_array: NumpyArray, p: int = 2, dim: int = 1, eps: float = 1e
2222
return normalized_array
2323

2424

25+
def mean_pooling(input_array: NumpyArray, attention_mask: NumpyArray) -> NumpyArray:
    """Average token embeddings along axis 1, weighted by the attention mask.

    Args:
        input_array: Token embeddings, reduced over axis 1 (the token axis);
            the last axis holds the embedding features.
        attention_mask: Per-token weights with one entry per (batch, token).
            Tokens whose mask is 0 do not contribute to the average.

    Returns:
        NumpyArray: One pooled vector per batch item, i.e. ``input_array``
        with the token axis reduced away.
    """
    # A trailing singleton axis is enough: broadcasting stretches the mask over
    # the feature axis, so there is no need to materialise a full-size copy of
    # it with np.tile.
    mask = np.expand_dims(attention_mask, axis=-1).astype(np.float32)
    sum_embeddings = np.sum(input_array * mask, axis=1)
    sum_mask = np.sum(mask, axis=1)
    # Clamp the denominator so a fully masked row yields zeros instead of NaN.
    return sum_embeddings / np.maximum(sum_mask, 1e-9)
33+
34+
2535
def iter_batch(iterable: Iterable[T], size: int) -> Iterable[list[T]]:
2636
"""
2737
>>> list(iter_batch([1,2,3,4,5], 3))

fastembed/text/clip_embedding.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222

2323

2424
class CLIPOnnxEmbedding(OnnxTextEmbedding):
25-
CUSTOM_MODELS: list[DenseModelDescription] = []
26-
2725
@classmethod
2826
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
2927
return CLIPEmbeddingWorker
@@ -35,7 +33,7 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
3533
Returns:
3634
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
3735
"""
38-
return supported_clip_models + cls.CUSTOM_MODELS
36+
return supported_clip_models
3937

4038
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
4139
return output.model_output

fastembed/text/custom_text_embedding.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from typing import Optional, Sequence, Any, Iterable
2+
3+
from dataclasses import dataclass
4+
from fastembed.common import OnnxProvider
5+
from fastembed.common.model_description import (
6+
PoolingType,
7+
DenseModelDescription,
8+
)
9+
from fastembed.common.onnx_model import OnnxOutputContext
10+
from fastembed.common.types import NumpyArray
11+
from fastembed.common.utils import normalize, mean_pooling
12+
from fastembed.text.onnx_embedding import OnnxTextEmbedding
13+
14+
15+
@dataclass(frozen=True)
class PostprocessingConfig:
    """Immutable pair of post-processing settings kept for a custom model."""

    # Pooling strategy to apply to the raw model output.
    pooling: PoolingType
    # Whether the pooled output should additionally be normalized.
    normalization: bool
19+
20+
21+
class CustomTextEmbedding(OnnxTextEmbedding):
    """ONNX text embedding whose post-processing is configured per model.

    Models are registered at runtime through :meth:`add_model`, which records
    both the model description and the pooling/normalization settings to apply
    to that model's output.
    """

    # Registry of user-added models; shared by the class and all instances.
    SUPPORTED_MODELS: list[DenseModelDescription] = []
    # Model name -> post-processing settings, filled in by `add_model`.
    POSTPROCESSING_MAPPING: dict[str, PostprocessingConfig] = {}

    def __init__(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence[OnnxProvider]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        lazy_load: bool = False,
        device_id: Optional[int] = None,
        specific_model_path: Optional[str] = None,
        **kwargs: Any,
    ):
        """Initialise the base ONNX embedding and load this model's post-processing config.

        All arguments are forwarded unchanged to ``OnnxTextEmbedding.__init__``.

        Raises:
            KeyError: If ``model_name`` has no entry in ``POSTPROCESSING_MAPPING``
                (the lookup uses the name exactly as passed).
        """
        super().__init__(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=threads,
            providers=providers,
            cuda=cuda,
            device_ids=device_ids,
            lazy_load=lazy_load,
            device_id=device_id,
            specific_model_path=specific_model_path,
            **kwargs,
        )
        # Single lookup; both settings come from the same registered config.
        config = self.POSTPROCESSING_MAPPING[model_name]
        self._pooling = config.pooling
        self._normalization = config.normalization

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Return the descriptions of all models registered via ``add_model``."""
        return cls.SUPPORTED_MODELS

    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
        """Pool the raw model output, then normalize it if the model is configured to."""
        pooled = self._pool(output.model_output, output.attention_mask)
        return self._normalize(pooled)

    def _pool(
        self, embeddings: NumpyArray, attention_mask: Optional[NumpyArray] = None
    ) -> NumpyArray:
        """Apply the configured pooling strategy to ``embeddings``.

        Args:
            embeddings: Raw model output.
            attention_mask: Per-token mask; required only for mean pooling.

        Returns:
            NumpyArray: The entry at index 0 along axis 1 (CLS), the
            mask-weighted mean (MEAN), or ``embeddings`` untouched (DISABLED).

        Raises:
            ValueError: If mean pooling is requested without an attention mask,
                or if the configured pooling type is not handled here.
        """
        if self._pooling == PoolingType.CLS:
            return embeddings[:, 0]

        if self._pooling == PoolingType.MEAN:
            if attention_mask is None:
                raise ValueError("attention_mask must be provided for mean pooling")
            return mean_pooling(embeddings, attention_mask)

        if self._pooling == PoolingType.DISABLED:
            return embeddings

        # Fail loudly instead of silently returning None for an unhandled type.
        raise ValueError(f"Unsupported pooling type {self._pooling}")

    def _normalize(self, embeddings: NumpyArray) -> NumpyArray:
        """Return ``normalize(embeddings)`` when normalization is enabled, else the input.

        NOTE(review): ``normalize`` is called with its defaults; if pooling is
        DISABLED and the output keeps a token axis, confirm that the default
        normalization axis is the intended one.
        """
        return normalize(embeddings) if self._normalization else embeddings

    @classmethod
    def add_model(
        cls,
        model_description: DenseModelDescription,
        pooling: PoolingType,
        normalization: bool,
    ) -> None:
        """Register a custom model together with its post-processing settings.

        Args:
            model_description: Description appended to ``SUPPORTED_MODELS``.
            pooling: Pooling strategy to apply to this model's output.
            normalization: Whether to normalize the pooled output.
        """
        cls.SUPPORTED_MODELS.append(model_description)
        cls.POSTPROCESSING_MAPPING[model_description.model] = PostprocessingConfig(
            pooling=pooling, normalization=normalization
        )

fastembed/text/multitask_embedding.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ class Task(int, Enum):
4343
class JinaEmbeddingV3(PooledNormalizedEmbedding):
4444
PASSAGE_TASK = Task.RETRIEVAL_PASSAGE
4545
QUERY_TASK = Task.RETRIEVAL_QUERY
46-
CUSTOM_MODELS: list[DenseModelDescription] = []
4746

4847
def __init__(self, *args: Any, **kwargs: Any):
4948
super().__init__(*args, **kwargs)
@@ -55,7 +54,7 @@ def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
5554

5655
@classmethod
5756
def _list_supported_models(cls) -> list[DenseModelDescription]:
58-
return supported_multitask_models + cls.CUSTOM_MODELS
57+
return supported_multitask_models
5958

6059
def _preprocess_onnx_input(
6160
self, onnx_input: dict[str, NumpyArray], **kwargs: Any

fastembed/text/onnx_embedding.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -195,35 +195,6 @@
195195
class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[NumpyArray]):
196196
"""Implementation of the Flag Embedding model."""
197197

198-
CUSTOM_MODELS: list[DenseModelDescription] = []
199-
200-
@classmethod
201-
def add_custom_model(
202-
cls,
203-
model: str,
204-
sources: ModelSource,
205-
model_file: str,
206-
dim: int,
207-
description: str,
208-
license: str,
209-
size_in_gb: float,
210-
additional_files: Optional[list[str]] = None,
211-
tasks: Optional[dict[str, Any]] = None,
212-
) -> None:
213-
cls.CUSTOM_MODELS.append(
214-
DenseModelDescription(
215-
model=model,
216-
sources=sources,
217-
dim=dim,
218-
model_file=model_file,
219-
description=description,
220-
license=license,
221-
size_in_GB=size_in_gb,
222-
additional_files=additional_files if additional_files else [],
223-
tasks=tasks if tasks else {},
224-
)
225-
)
226-
227198
@classmethod
228199
def _list_supported_models(cls) -> list[DenseModelDescription]:
229200
"""
@@ -232,7 +203,7 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
232203
Returns:
233204
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
234205
"""
235-
return supported_onnx_models + cls.CUSTOM_MODELS
206+
return supported_onnx_models
236207

237208
def __init__(
238209
self,

fastembed/text/pooled_embedding.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from fastembed.common.types import NumpyArray
66
from fastembed.common.onnx_model import OnnxOutputContext
7+
from fastembed.common.utils import mean_pooling
78
from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
89
from fastembed.common.model_description import DenseModelDescription, ModelSource
910

@@ -88,8 +89,6 @@
8889

8990

9091
class PooledEmbedding(OnnxTextEmbedding):
91-
CUSTOM_MODELS: list[DenseModelDescription] = []
92-
9392
@classmethod
9493
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
9594
return PooledEmbeddingWorker
@@ -98,13 +97,7 @@ def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
9897
def mean_pooling(cls, model_output: NumpyArray, attention_mask: NumpyArray) -> NumpyArray:
9998
token_embeddings = model_output.astype(np.float32)
10099
attention_mask = attention_mask.astype(np.float32)
101-
input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
102-
input_mask_expanded = np.tile(input_mask_expanded, (1, 1, token_embeddings.shape[-1]))
103-
input_mask_expanded = input_mask_expanded.astype(np.float32)
104-
sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
105-
sum_mask = np.sum(input_mask_expanded, axis=1)
106-
pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
107-
return pooled_embeddings
100+
return mean_pooling(token_embeddings, attention_mask)
108101

109102
@classmethod
110103
def _list_supported_models(cls) -> list[DenseModelDescription]:
@@ -113,7 +106,7 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
113106
Returns:
114107
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
115108
"""
116-
return supported_pooled_models + cls.CUSTOM_MODELS
109+
return supported_pooled_models
117110

118111
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
119112
if output.attention_mask is None:

fastembed/text/pooled_normalized_embedding.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@
113113

114114

115115
class PooledNormalizedEmbedding(PooledEmbedding):
116-
CUSTOM_MODELS: list[DenseModelDescription] = []
117-
118116
@classmethod
119117
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
120118
return PooledNormalizedEmbeddingWorker
@@ -126,7 +124,7 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
126124
Returns:
127125
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
128126
"""
129-
return supported_pooled_normalized_models + cls.CUSTOM_MODELS
127+
return supported_pooled_normalized_models
130128

131129
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
132130
if output.attention_mask is None:

fastembed/text/text_embedding.py

Lines changed: 15 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from fastembed.common.types import NumpyArray, OnnxProvider
66
from fastembed.text.clip_embedding import CLIPOnnxEmbedding
7+
from fastembed.text.custom_text_embedding import CustomTextEmbedding
78
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
89
from fastembed.text.pooled_embedding import PooledEmbedding
910
from fastembed.text.multitask_embedding import JinaEmbeddingV3
@@ -19,6 +20,7 @@ class TextEmbedding(TextEmbeddingBase):
1920
PooledNormalizedEmbedding,
2021
PooledEmbedding,
2122
JinaEmbeddingV3,
23+
CustomTextEmbedding,
2224
]
2325

2426
@classmethod
@@ -50,7 +52,6 @@ def add_custom_model(
5052
license: str = "",
5153
size_in_gb: float = 0.0,
5254
additional_files: Optional[list[str]] = None,
53-
tasks: Optional[dict[str, Any]] = None,
5455
) -> None:
5556
registered_models = cls._list_supported_models()
5657
for registered_model in registered_models:
@@ -60,56 +61,20 @@ def add_custom_model(
6061
f"please use another model name"
6162
)
6263

63-
if tasks:
64-
if pooling == PoolingType.MEAN and normalization:
65-
JinaEmbeddingV3.add_custom_model(
66-
model=model,
67-
sources=sources,
68-
dim=dim,
69-
model_file=model_file,
70-
description=description,
71-
license=license,
72-
size_in_gb=size_in_gb,
73-
additional_files=additional_files,
74-
tasks=tasks,
75-
)
76-
return None
77-
else:
78-
raise ValueError(
79-
"Multitask models supported only with pooling=Pooling.MEAN and normalization=True, current values:"
80-
f"pooling={pooling}, normalization={normalization}, tasks: {tasks}"
81-
)
82-
83-
embedding_cls: Type[OnnxTextEmbedding]
84-
if pooling == PoolingType.MEAN and normalization:
85-
embedding_cls = PooledNormalizedEmbedding
86-
elif pooling == PoolingType.MEAN and not normalization:
87-
embedding_cls = PooledEmbedding
88-
elif (pooling == PoolingType.CLS or PoolingType.DISABLED) and normalization:
89-
embedding_cls = OnnxTextEmbedding
90-
elif pooling == PoolingType.DISABLED and not normalization:
91-
embedding_cls = CLIPOnnxEmbedding
92-
else:
93-
raise ValueError(
94-
"Only the following combinations of pooling and normalization are currently supported:"
95-
"pooling=Pooling.MEAN + normalization=True;\n"
96-
"pooling=Pooling.MEAN + normalization=False;\n"
97-
"pooling=Pooling.CLS + normalization=True;\n"
98-
"pooling=Pooling.DISABLED + normalization=False;\n"
99-
)
100-
101-
embedding_cls.add_custom_model(
102-
model=model,
103-
sources=sources,
104-
dim=dim,
105-
model_file=model_file,
106-
description=description,
107-
license=license,
108-
size_in_gb=size_in_gb,
109-
additional_files=additional_files,
110-
tasks=tasks,
64+
CustomTextEmbedding.add_model(
65+
DenseModelDescription(
66+
model=model,
67+
sources=sources,
68+
dim=dim,
69+
model_file=model_file,
70+
description=description,
71+
license=license,
72+
size_in_GB=size_in_gb,
73+
additional_files=additional_files or [],
74+
),
75+
pooling=pooling,
76+
normalization=normalization,
11177
)
112-
return None
11378

11479
def __init__(
11580
self,

0 commit comments

Comments
 (0)