Skip to content

Commit 8f04f57

Browse files
committed
new: add gemma embed
1 parent 685fd9b commit 8f04f57

File tree

4 files changed

+131
-2
lines changed

4 files changed

+131
-2
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Any, Iterable, Type
2+
3+
4+
from fastembed.common.types import NumpyArray
5+
from fastembed.common.onnx_model import OnnxOutputContext
6+
from fastembed.common.utils import normalize
7+
from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
8+
from fastembed.common.model_description import DenseModelDescription, ModelSource
9+
10+
11+
supported_builtin_pooling_normalized_models: list[DenseModelDescription] = [
12+
DenseModelDescription(
13+
model="google/embeddinggemma-300m",
14+
dim=768,
15+
description=(
16+
"Text embeddings, Unimodal (text), multilingual, 2048 input tokens truncation, "
17+
"Prefixes for queries/documents: `task: search result | query: {content}` for query, "
18+
"`title: {title | 'none'} | text: {content}` for documents, 2025 year."
19+
),
20+
license="apache-2.0",
21+
size_in_GB=1.24,
22+
sources=ModelSource(
23+
hf="onnx-community/embeddinggemma-300m-ONNX",
24+
),
25+
model_file="onnx/model.onnx",
26+
additional_files=["onnx/model.onnx_data"],
27+
),
28+
]
29+
30+
31+
class BuiltinPoolingNormalizedEmbedding(OnnxTextEmbedding):
32+
@classmethod
33+
def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
34+
return BuiltinPoolingNormalizedEmbeddingWorker
35+
36+
@classmethod
37+
def _list_supported_models(cls) -> list[DenseModelDescription]:
38+
"""Lists the supported models.
39+
40+
Returns:
41+
list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
42+
"""
43+
return supported_builtin_pooling_normalized_models
44+
45+
def _post_process_onnx_output(
46+
self, output: OnnxOutputContext, **kwargs: Any
47+
) -> Iterable[NumpyArray]:
48+
return normalize(output.model_output)
49+
50+
def _run_model(
51+
self, onnx_input: dict[str, Any], onnx_output_names: list[str] | None = None
52+
) -> NumpyArray:
53+
return self.model.run(onnx_output_names, onnx_input)[1] # type: ignore[union-attr]
54+
55+
56+
class BuiltinPoolingNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker):
57+
def init_embedding(
58+
self,
59+
model_name: str,
60+
cache_dir: str,
61+
**kwargs: Any,
62+
) -> OnnxTextEmbedding:
63+
return BuiltinPoolingNormalizedEmbedding(
64+
model_name=model_name,
65+
cache_dir=cache_dir,
66+
threads=1,
67+
**kwargs,
68+
)

fastembed/text/onnx_text_model.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,21 @@ def onnx_embed(
9292
[np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64
9393
)
9494
onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)
95+
model_output = self._run_model(
96+
onnx_input=onnx_input, onnx_output_names=self.ONNX_OUTPUT_NAMES
97+
)
9598

96-
model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr]
9799
return OnnxOutputContext(
98-
model_output=model_output[0],
100+
model_output=model_output,
99101
attention_mask=onnx_input.get("attention_mask", attention_mask),
100102
input_ids=onnx_input.get("input_ids", input_ids),
101103
)
102104

105+
def _run_model(
106+
self, onnx_input: dict[str, Any], onnx_output_names: list[str] | None = None
107+
) -> NumpyArray:
108+
return self.model.run(onnx_output_names, onnx_input)[0] # type: ignore[union-attr]
109+
103110
def _embed_documents(
104111
self,
105112
model_name: str,

fastembed/text/text_embedding.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
99
from fastembed.text.pooled_embedding import PooledEmbedding
1010
from fastembed.text.multitask_embedding import JinaEmbeddingV3
11+
from fastembed.text.builtin_pooling_normalized_embedding import BuiltinPoolingNormalizedEmbedding
1112
from fastembed.text.onnx_embedding import OnnxTextEmbedding
1213
from fastembed.text.text_embedding_base import TextEmbeddingBase
1314
from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType
@@ -20,6 +21,7 @@ class TextEmbedding(TextEmbeddingBase):
2021
PooledNormalizedEmbedding,
2122
PooledEmbedding,
2223
JinaEmbeddingV3,
24+
BuiltinPoolingNormalizedEmbedding,
2325
CustomTextEmbedding,
2426
]
2527

tests/test_text_onnx_embeddings.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,22 @@
6868
"Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
6969
"thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]),
7070
"jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
71+
"google/embeddinggemma-300m": np.array(
72+
[-0.08181356, 0.0214127, 0.05120273, -0.03690156, -0.0254504]
73+
),
74+
}
75+
76+
77+
DOC_PREFIXES = {
78+
"google/embeddinggemma-300m": "title: none | text: ",
79+
}
80+
QUERY_PREFIXES = {
81+
"google/embeddinggemma-300m": "task: search result | query: ",
82+
}
83+
CANONICAL_QUERY_VECTOR_VALUES = {
84+
"google/embeddinggemma-300m": np.array(
85+
[-0.22990295, 0.03311195, 0.04290345, -0.03558498, -0.01399477]
86+
)
7187
}
7288

7389
MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"]
@@ -119,6 +135,9 @@ def test_embedding(model_cache, model_name: str) -> None:
119135

120136
with model_cache(model_desc.model) as model:
121137
docs = ["hello world", "flag embedding"]
138+
if model_desc.model in DOC_PREFIXES:
139+
docs = [DOC_PREFIXES[model_desc.model] + doc for doc in docs]
140+
122141
embeddings = list(model.embed(docs))
123142
embeddings = np.stack(embeddings, axis=0)
124143
assert embeddings.shape == (2, dim)
@@ -129,6 +148,39 @@ def test_embedding(model_cache, model_name: str) -> None:
129148
), model_desc.model
130149

131150

151+
def test_query_embedding(model_cache) -> None:
152+
is_ci = os.getenv("CI")
153+
is_mac = platform.system() == "Darwin"
154+
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
155+
156+
for model_desc in TextEmbedding._list_supported_models():
157+
if model_desc.model in MULTI_TASK_MODELS or (
158+
is_mac and model_desc.model == "nomic-ai/nomic-embed-text-v1.5-Q"
159+
):
160+
continue
161+
162+
if model_desc.model not in CANONICAL_QUERY_VECTOR_VALUES:
163+
continue
164+
165+
if not should_test_model(model_desc, "", is_ci, is_manual):
166+
continue
167+
168+
dim = model_desc.dim
169+
with model_cache(model_desc.model) as model:
170+
queries = ["hello world", "flag embedding"]
171+
if model_desc.model in QUERY_PREFIXES:
172+
queries = [QUERY_PREFIXES[model_desc.model] + query for query in queries]
173+
174+
embeddings = list(model.query_embed(queries))
175+
embeddings = np.stack(embeddings, axis=0)
176+
assert embeddings.shape == (2, dim)
177+
178+
canonical_vector = CANONICAL_QUERY_VECTOR_VALUES[model_desc.model]
179+
assert np.allclose(
180+
embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
181+
), model_desc.model
182+
183+
132184
@pytest.mark.parametrize("n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5")])
133185
def test_batch_embedding(model_cache, n_dims: int, model_name: str) -> None:
134186
with model_cache(model_name) as model:

0 commit comments

Comments
 (0)