44
55from fastembed .common .types import NumpyArray , OnnxProvider
66from fastembed .text .clip_embedding import CLIPOnnxEmbedding
7+ from fastembed .text .custom_text_embedding import CustomTextEmbedding
78from fastembed .text .pooled_normalized_embedding import PooledNormalizedEmbedding
89from fastembed .text .pooled_embedding import PooledEmbedding
910from fastembed .text .multitask_embedding import JinaEmbeddingV3
@@ -19,6 +20,7 @@ class TextEmbedding(TextEmbeddingBase):
1920 PooledNormalizedEmbedding ,
2021 PooledEmbedding ,
2122 JinaEmbeddingV3 ,
23+ CustomTextEmbedding ,
2224 ]
2325
2426 @classmethod
@@ -50,7 +52,6 @@ def add_custom_model(
5052 license : str = "" ,
5153 size_in_gb : float = 0.0 ,
5254 additional_files : Optional [list [str ]] = None ,
53- tasks : Optional [dict [str , Any ]] = None ,
5455 ) -> None :
5556 registered_models = cls ._list_supported_models ()
5657 for registered_model in registered_models :
@@ -60,56 +61,20 @@ def add_custom_model(
6061 f"please use another model name"
6162 )
6263
63- if tasks :
64- if pooling == PoolingType .MEAN and normalization :
65- JinaEmbeddingV3 .add_custom_model (
66- model = model ,
67- sources = sources ,
68- dim = dim ,
69- model_file = model_file ,
70- description = description ,
71- license = license ,
72- size_in_gb = size_in_gb ,
73- additional_files = additional_files ,
74- tasks = tasks ,
75- )
76- return None
77- else :
78- raise ValueError (
79- "Multitask models supported only with pooling=Pooling.MEAN and normalization=True, current values:"
80- f"pooling={ pooling } , normalization={ normalization } , tasks: { tasks } "
81- )
82-
83- embedding_cls : Type [OnnxTextEmbedding ]
84- if pooling == PoolingType .MEAN and normalization :
85- embedding_cls = PooledNormalizedEmbedding
86- elif pooling == PoolingType .MEAN and not normalization :
87- embedding_cls = PooledEmbedding
88- elif (pooling == PoolingType .CLS or PoolingType .DISABLED ) and normalization :
89- embedding_cls = OnnxTextEmbedding
90- elif pooling == PoolingType .DISABLED and not normalization :
91- embedding_cls = CLIPOnnxEmbedding
92- else :
93- raise ValueError (
94- "Only the following combinations of pooling and normalization are currently supported:"
95- "pooling=Pooling.MEAN + normalization=True;\n "
96- "pooling=Pooling.MEAN + normalization=False;\n "
97- "pooling=Pooling.CLS + normalization=True;\n "
98- "pooling=Pooling.DISABLED + normalization=False;\n "
99- )
100-
101- embedding_cls .add_custom_model (
102- model = model ,
103- sources = sources ,
104- dim = dim ,
105- model_file = model_file ,
106- description = description ,
107- license = license ,
108- size_in_gb = size_in_gb ,
109- additional_files = additional_files ,
110- tasks = tasks ,
64+ CustomTextEmbedding .add_model (
65+ DenseModelDescription (
66+ model = model ,
67+ sources = sources ,
68+ dim = dim ,
69+ model_file = model_file ,
70+ description = description ,
71+ license = license ,
72+ size_in_GB = size_in_gb ,
73+ additional_files = additional_files or [],
74+ ),
75+ pooling = pooling ,
76+ normalization = normalization ,
11177 )
112- return None
11378
11479 def __init__ (
11580 self ,
0 commit comments