Skip to content

Commit d5da562

Browse files
authored
new: add embedding size property (#521)
* new: add embedding size property * fix: format exception message * new: replace embedding size property with get_embedding_size classmethod * fix: fix missed parts * chore: fix docstrings * new: add embedding_size property
1 parent 4e5575f commit d5da562

12 files changed

+258
-0
lines changed

fastembed/image/image_embedding.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,40 @@ def __init__(
7777
"Please check the supported models using `ImageEmbedding.list_supported_models()`"
7878
)
7979

80+
@property
81+
def embedding_size(self) -> int:
82+
"""Get the embedding size of the current model"""
83+
if self._embedding_size is None:
84+
self._embedding_size = self.get_embedding_size(self.model_name)
85+
return self._embedding_size
86+
87+
@classmethod
88+
def get_embedding_size(cls, model_name: str) -> int:
89+
"""Get the embedding size of the passed model
90+
91+
Args:
92+
model_name (str): The name of the model to get embedding size for.
93+
94+
Returns:
95+
int: The size of the embedding.
96+
97+
Raises:
98+
ValueError: If the model name is not found in the supported models.
99+
"""
100+
descriptions = cls._list_supported_models()
101+
embedding_size: Optional[int] = None
102+
for description in descriptions:
103+
if description.model.lower() == model_name.lower():
104+
embedding_size = description.dim
105+
break
106+
if embedding_size is None:
107+
model_names = [description.model for description in descriptions]
108+
raise ValueError(
109+
f"Embedding size for model {model_name} was None. "
110+
f"Available model names: {model_names}"
111+
)
112+
return embedding_size
113+
80114
def embed(
81115
self,
82116
images: Union[ImageInput, Iterable[ImageInput]],

fastembed/image/image_embedding_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(
1818
self.cache_dir = cache_dir
1919
self.threads = threads
2020
self._local_files_only = kwargs.pop("local_files_only", False)
21+
self._embedding_size: Optional[int] = None
2122

2223
def embed(
2324
self,
@@ -42,3 +43,13 @@ def embed(
4243
Iterable[NdArray]: The embeddings.
4344
"""
4445
raise NotImplementedError()
46+
47+
@classmethod
48+
def get_embedding_size(cls, model_name: str) -> int:
49+
"""Returns embedding size of the chosen model."""
50+
raise NotImplementedError("Subclasses must implement this method")
51+
52+
@property
53+
def embedding_size(self) -> int:
54+
"""Returns embedding size for the current model"""
55+
raise NotImplementedError("Subclasses must implement this method")

fastembed/late_interaction/late_interaction_embedding_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def __init__(
1717
self.cache_dir = cache_dir
1818
self.threads = threads
1919
self._local_files_only = kwargs.pop("local_files_only", False)
20+
self._embedding_size: Optional[int] = None
2021

2122
def embed(
2223
self,
@@ -58,3 +59,13 @@ def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterab
5859
yield from self.embed([query], **kwargs)
5960
else:
6061
yield from self.embed(query, **kwargs)
62+
63+
@classmethod
64+
def get_embedding_size(cls, model_name: str) -> int:
65+
"""Returns embedding size of the chosen model."""
66+
raise NotImplementedError("Subclasses must implement this method")
67+
68+
@property
69+
def embedding_size(self) -> int:
70+
"""Returns embedding size for the current model"""
71+
raise NotImplementedError("Subclasses must implement this method")

fastembed/late_interaction/late_interaction_text_embedding.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,40 @@ def __init__(
8080
"Please check the supported models using `LateInteractionTextEmbedding.list_supported_models()`"
8181
)
8282

83+
@property
84+
def embedding_size(self) -> int:
85+
"""Get the embedding size of the current model"""
86+
if self._embedding_size is None:
87+
self._embedding_size = self.get_embedding_size(self.model_name)
88+
return self._embedding_size
89+
90+
@classmethod
91+
def get_embedding_size(cls, model_name: str) -> int:
92+
"""Get the embedding size of the passed model
93+
94+
Args:
95+
model_name (str): The name of the model to get embedding size for.
96+
97+
Returns:
98+
int: The size of the embedding.
99+
100+
Raises:
101+
ValueError: If the model name is not found in the supported models.
102+
"""
103+
descriptions = cls._list_supported_models()
104+
embedding_size: Optional[int] = None
105+
for description in descriptions:
106+
if description.model.lower() == model_name.lower():
107+
embedding_size = description.dim
108+
break
109+
if embedding_size is None:
110+
model_names = [description.model for description in descriptions]
111+
raise ValueError(
112+
f"Embedding size for model {model_name} was None. "
113+
f"Available model names: {model_names}"
114+
)
115+
return embedding_size
116+
83117
def embed(
84118
self,
85119
documents: Union[str, Iterable[str]],

fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,40 @@ def __init__(
8383
"Please check the supported models using `LateInteractionMultimodalEmbedding.list_supported_models()`"
8484
)
8585

86+
@property
87+
def embedding_size(self) -> int:
88+
"""Get the embedding size of the current model"""
89+
if self._embedding_size is None:
90+
self._embedding_size = self.get_embedding_size(self.model_name)
91+
return self._embedding_size
92+
93+
@classmethod
94+
def get_embedding_size(cls, model_name: str) -> int:
95+
"""Get the embedding size of the passed model
96+
97+
Args:
98+
model_name (str): The name of the model to get embedding size for.
99+
100+
Returns:
101+
int: The size of the embedding.
102+
103+
Raises:
104+
ValueError: If the model name is not found in the supported models.
105+
"""
106+
descriptions = cls._list_supported_models()
107+
embedding_size: Optional[int] = None
108+
for description in descriptions:
109+
if description.model.lower() == model_name.lower():
110+
embedding_size = description.dim
111+
break
112+
if embedding_size is None:
113+
model_names = [description.model for description in descriptions]
114+
raise ValueError(
115+
f"Embedding size for model {model_name} was None. "
116+
f"Available model names: {model_names}"
117+
)
118+
return embedding_size
119+
86120
def embed_text(
87121
self,
88122
documents: Union[str, Iterable[str]],

fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
self.cache_dir = cache_dir
2020
self.threads = threads
2121
self._local_files_only = kwargs.pop("local_files_only", False)
22+
self._embedding_size: Optional[int] = None
2223

2324
def embed_text(
2425
self,
@@ -65,3 +66,13 @@ def embed_image(
6566
List of embeddings, one per image
6667
"""
6768
raise NotImplementedError()
69+
70+
@classmethod
71+
def get_embedding_size(cls, model_name: str) -> int:
72+
"""Returns embedding size of the chosen model."""
73+
raise NotImplementedError("Subclasses must implement this method")
74+
75+
@property
76+
def embedding_size(self) -> int:
77+
"""Returns embedding size for the current model"""
78+
raise NotImplementedError("Subclasses must implement this method")

fastembed/text/text_embedding.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,40 @@ def __init__(
128128
"Please check the supported models using `TextEmbedding.list_supported_models()`"
129129
)
130130

131+
@property
132+
def embedding_size(self) -> int:
133+
"""Get the embedding size of the current model"""
134+
if self._embedding_size is None:
135+
self._embedding_size = self.get_embedding_size(self.model_name)
136+
return self._embedding_size
137+
138+
@classmethod
139+
def get_embedding_size(cls, model_name: str) -> int:
140+
"""Get the embedding size of the passed model
141+
142+
Args:
143+
model_name (str): The name of the model to get embedding size for.
144+
145+
Returns:
146+
int: The size of the embedding.
147+
148+
Raises:
149+
ValueError: If the model name is not found in the supported models.
150+
"""
151+
descriptions = cls._list_supported_models()
152+
embedding_size: Optional[int] = None
153+
for description in descriptions:
154+
if description.model.lower() == model_name.lower():
155+
embedding_size = description.dim
156+
break
157+
if embedding_size is None:
158+
model_names = [description.model for description in descriptions]
159+
raise ValueError(
160+
f"Embedding size for model {model_name} was None. "
161+
f"Available model names: {model_names}"
162+
)
163+
return embedding_size
164+
131165
def embed(
132166
self,
133167
documents: Union[str, Iterable[str]],

fastembed/text/text_embedding_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def __init__(
1717
self.cache_dir = cache_dir
1818
self.threads = threads
1919
self._local_files_only = kwargs.pop("local_files_only", False)
20+
self._embedding_size: Optional[int] = None
2021

2122
def embed(
2223
self,
@@ -58,3 +59,13 @@ def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterab
5859
yield from self.embed([query], **kwargs)
5960
else:
6061
yield from self.embed(query, **kwargs)
62+
63+
@classmethod
64+
def get_embedding_size(cls, model_name: str) -> int:
65+
"""Returns embedding size of the passed model."""
66+
raise NotImplementedError("Subclasses must implement this method")
67+
68+
@property
69+
def embedding_size(self) -> int:
70+
"""Returns embedding size for the current model"""
71+
raise NotImplementedError("Subclasses must implement this method")

tests/test_image_onnx_embeddings.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,21 @@ def test_lazy_load(model_name: str) -> None:
127127
assert hasattr(model.model, "model")
128128
if is_ci:
129129
delete_model_cache(model.model._model_dir)
130+
131+
132+
def test_get_embedding_size() -> None:
133+
assert ImageEmbedding.get_embedding_size(model_name="Qdrant/clip-ViT-B-32-vision") == 512
134+
assert ImageEmbedding.get_embedding_size(model_name="Qdrant/clip-vit-b-32-vision") == 512
135+
136+
137+
def test_embedding_size() -> None:
138+
is_ci = os.getenv("CI")
139+
model_name = "Qdrant/clip-ViT-B-32-vision"
140+
model = ImageEmbedding(model_name=model_name, lazy_load=True)
141+
assert model.embedding_size == 512
142+
143+
model_name = "Qdrant/clip-vit-b-32-vision"
144+
model = ImageEmbedding(model_name=model_name, lazy_load=True)
145+
assert model.embedding_size == 512
146+
if is_ci:
147+
delete_model_cache(model.model._model_dir)

tests/test_late_interaction_embeddings.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,24 @@ def test_lazy_load(model_name: str):
254254

255255
if is_ci:
256256
delete_model_cache(model.model._model_dir)
257+
258+
259+
def test_get_embedding_size():
260+
model_name = "answerdotai/answerai-colbert-small-v1"
261+
assert LateInteractionTextEmbedding.get_embedding_size(model_name) == 96
262+
263+
model_name = "answerdotai/answerai-ColBERT-small-v1"
264+
assert LateInteractionTextEmbedding.get_embedding_size(model_name) == 96
265+
266+
267+
def test_embedding_size():
268+
is_ci = os.getenv("CI")
269+
model_name = "answerdotai/answerai-colbert-small-v1"
270+
model = LateInteractionTextEmbedding(model_name=model_name, lazy_load=True)
271+
assert model.embedding_size == 96
272+
273+
model_name = "answerdotai/answerai-ColBERT-small-v1"
274+
model = LateInteractionTextEmbedding(model_name=model_name, lazy_load=True)
275+
assert model.embedding_size == 96
276+
if is_ci:
277+
delete_model_cache(model.model._model_dir)

0 commit comments

Comments
 (0)